Add streaming safetensors loading with disk weight cache

ifilipis 2026-01-08 13:02:08 +02:00
parent 3cd7b32f1b
commit f925f8fa77
24 changed files with 1640 additions and 216 deletions

DESIGN.md Normal file

@ -0,0 +1,57 @@
# Disk tier safetensors streaming design audit (ComfyUI)
## Mandatory research audit (verified call sites)
### ComfyUI load path + eager materialization sites
- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】
- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】
- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes if used on a streaming mapping).【F:comfy/utils.py†L135-L144】
- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】
- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
### FastSageTensors (fastsafetensors) capability audit
- Header parsing and metadata:
- `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
- `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
- GDS + no-GDS low-level readers:
- `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
- GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
- DLPack wrapping:
- `fastsafetensors/dlpack.py` provides `from_cuda_buffer()` which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used for `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
- Torch framework interop:
- `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
### VRAM/RAM offload logic (for extension)
- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps track of loaded/offloaded memory (needs integration with the disk tier for RAM eviction).【F:comfy/model_management.py†L584-L612】
- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
## Strategy summary (no coding performed yet)
### Streaming safetensors mapping (no full dict materialization)
- Introduce a new module `comfy/safetensors_stream.py` (see the sketch after this list) with:
- `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
- `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
- Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
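A minimal sketch of the mapping contract, assuming a `SafeTensorIndex` that exposes `keys()`/`meta()` and a range-read helper (illustrative names, not the final API):

```python
# Minimal sketch of a lazy state-dict mapping; names are illustrative.
from collections.abc import Mapping

class StreamStateDict(Mapping):
    def __init__(self, index):
        self.index = index                 # SafeTensorIndex: header metadata only

    def __iter__(self):
        return iter(self.index.keys())     # keys come from the parsed header, no I/O

    def __len__(self):
        return len(self.index.keys())

    def meta(self, key):
        return self.index.meta(key)        # dtype/shape/offsets without reading data

    def __getitem__(self, key):
        # The only place tensor bytes are actually read from disk.
        return self.index.read_tensor(key, device="cpu")
```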
### Range reads and tiering
- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap the result with DLPack (alignment arithmetic is sketched after this list).
- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
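Worked sketch of the alignment arithmetic needed for O_DIRECT/GDS range reads (a 4096-byte alignment is assumed here for illustration; the real code queries `get_alignment_size()`):

```python
# Sketch: widen a tensor's byte range to an aligned read window.
ALIGN = 4096

def aligned_range(abs_start: int, length: int):
    aligned_offset = (abs_start // ALIGN) * ALIGN  # round the start down
    head = abs_start - aligned_offset              # leading bytes to skip
    aligned_length = length + head
    tail = aligned_length % ALIGN
    if tail:
        aligned_length += ALIGN - tail             # round the length up
    return aligned_offset, aligned_length, head

# A tensor at byte 10000 with 5000 bytes reads the window [8192, 16384)
# and skips the first 1808 bytes of the buffer.
assert aligned_range(10000, 5000) == (8192, 8192, 1808)
```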
### Disk tier integration
- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
- Add an LRU cache for RAM-resident weights loaded from disk with a configurable max size in bytes. Eviction replaces RAM tensors with meta tensors and keeps the `DiskRef` for reload (see the eviction sketch after this list).
- Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
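Sketch of the eviction step, assuming the `DiskRef` registry keeps enough information to reload the weight:

```python
# Sketch: evict one parameter to the disk tier by swapping in a meta tensor.
import torch

def evict_param(module: torch.nn.Module, name: str, disk_ref) -> None:
    src = module._parameters[name]
    meta = torch.empty_like(src, device="meta")  # keeps shape/dtype, frees storage
    module._parameters[name] = torch.nn.Parameter(meta, requires_grad=src.requires_grad)
    # disk_ref stays registered; the forward_pre_hook reloads the weight on demand
```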
### Pipeline refactors
- Update `load_torch_file` to return a `StreamStateDict` for `.safetensors`/`.sft` files, returning file metadata without loading tensor data.
- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy (see the sketch after this list).
- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
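For instance, a metadata-aware `calculate_parameters` might look like this sketch (the `meta()` accessor is the streaming API proposed above):

```python
# Sketch: parameter counting that never touches tensor data on streaming mappings.
import math

def calculate_parameters(sd, prefix: str = "") -> int:
    total = 0
    for k in sd.keys():
        if not k.startswith(prefix):
            continue
        if hasattr(sd, "meta"):
            total += math.prod(sd.meta(k).shape)  # header metadata only
        else:
            total += sd[k].nelement()             # plain dict fallback
    return total
```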
### Tests and docs
- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and the GDS failure path (a test sketch follows).
- Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
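Sketch of the metadata-correctness unit test (pytest style; assumes `comfy.utils.load_torch_file` returns the streaming mapping for safetensors files):

```python
# Sketch of the metadata test idea; helper behavior is assumed, not final.
import torch
import safetensors.torch
from comfy.utils import load_torch_file

def test_stream_metadata_matches_source(tmp_path):
    path = str(tmp_path / "weights.safetensors")
    ref = {"a": torch.randn(4, 8), "b": torch.arange(6, dtype=torch.int32)}
    safetensors.torch.save_file(ref, path)
    sd = load_torch_file(path)
    assert set(sd.keys()) == set(ref.keys())
    for k, t in ref.items():
        m = sd.meta(k)  # metadata lookup must not read tensor data
        assert tuple(m.shape) == tuple(t.shape)
        assert m.dtype == t.dtype
```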


@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
### Disk tier for model weights
When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, cached weights are dropped from RAM (the safetensors file on disk remains the source) and reloaded on demand.
If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
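For example (hypothetical values): `python main.py --weights-ram-cache-gb 8 --weights-gds` keeps up to 8 GB of streamed weights in RAM and attempts GDS reads for disk→GPU loads.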
# Running


@ -29,7 +29,7 @@ class AudioEncoderModel():
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()


@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")


@ -48,7 +48,7 @@ class ClipVisionModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()


@ -439,7 +439,7 @@ def controlnet_config(sd, model_options={}):
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -473,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
@ -748,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -874,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
else:
return None
missing, unexpected = model_ad.load_state_dict(t2i_data)
missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))

comfy/disk_weights.py Normal file

@ -0,0 +1,275 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import collections
import weakref
from dataclasses import dataclass
from typing import Dict, Optional
import torch
ALLOW_GDS = False
PIN_IF_CPU = False
DISK_WEIGHTS_ENABLED = False
@dataclass
class DiskTensorRef:
state_dict: object
key: str
meta: object
requires_grad: bool
is_buffer: bool
def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor:
dtype = getattr(self.meta, "dtype", None)
if hasattr(self.state_dict, "get_tensor"):
return self.state_dict.get_tensor(
self.key,
device=device,
dtype=dtype,
allow_gds=allow_gds,
pin_if_cpu=pin_if_cpu,
)
tensor = self.state_dict[self.key]
if device is not None and tensor.device != device:
tensor = tensor.to(device=device)
if dtype is not None and tensor.dtype != dtype:
tensor = tensor.to(dtype=dtype)
return tensor
class DiskWeightRegistry:
def __init__(self):
self._registry = weakref.WeakKeyDictionary()
def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef):
module_refs = self._registry.setdefault(module, {})
module_refs[name] = ref
def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]:
return self._registry.get(module)
def has(self, module: torch.nn.Module) -> bool:
return module in self._registry
@dataclass
class CacheEntry:
module_ref: weakref.ReferenceType
name: str
size_bytes: int
is_buffer: bool
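# LRU accounting for CPU-resident disk-backed weights; exceeding max_bytes evicts
# the oldest entries back to meta tensors (reloadable via their DiskTensorRef).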
class DiskWeightCache:
def __init__(self, max_bytes: int = 0):
self.max_bytes = max_bytes
self.current_bytes = 0
self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict()
def set_limit(self, max_bytes: int):
self.max_bytes = max_bytes
self._evict_if_needed()
def _entry_key(self, module: torch.nn.Module, name: str) -> tuple[int, str]:
return (id(module), name)
def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
if tensor.device.type != "cpu":
return
size_bytes = tensor.numel() * tensor.element_size()
key = self._entry_key(module, name)
if key in self._entries:
entry = self._entries.pop(key)
self.current_bytes -= entry.size_bytes
module_ref = weakref.ref(module, self._drop_module_entries)
self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
self.current_bytes += size_bytes
self._evict_if_needed()
def touch(self, module: torch.nn.Module, name: str):
key = self._entry_key(module, name)
if key in self._entries:
entry = self._entries.pop(key)
self._entries[key] = entry
def evict_bytes(self, bytes_to_free: int):
freed = 0
while self._entries and freed < bytes_to_free:
_, entry = self._entries.popitem(last=False)
freed += entry.size_bytes
self.current_bytes -= entry.size_bytes
module = entry.module_ref()
if module is not None:
_evict_module_weight(module, entry.name, entry.is_buffer)
return freed
def _drop_module_entries(self, module_ref: weakref.ReferenceType):
to_remove = []
for key, entry in self._entries.items():
if entry.module_ref is module_ref:
to_remove.append(key)
for key in to_remove:
entry = self._entries.pop(key)
self.current_bytes -= entry.size_bytes
def _evict_if_needed(self):
while self._entries and self.current_bytes > self.max_bytes:
_, entry = self._entries.popitem(last=False)
self.current_bytes -= entry.size_bytes
module = entry.module_ref()
if module is not None:
_evict_module_weight(module, entry.name, entry.is_buffer)
REGISTRY = DiskWeightRegistry()
CACHE = DiskWeightCache(0)
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
ALLOW_GDS = allow_gds
PIN_IF_CPU = pin_if_cpu
DISK_WEIGHTS_ENABLED = enabled
CACHE.set_limit(cache_bytes if enabled else 0)
if not enabled:
CACHE._entries.clear()
CACHE.current_bytes = 0
def disk_weights_enabled() -> bool:
return DISK_WEIGHTS_ENABLED
def _owning_module(root: torch.nn.Module, dotted_name: str):
    # Resolve a dotted name like "blocks.0.weight" to (root.blocks[0], "weight") so
    # registry/cache entries always use the leaf attribute on the owning submodule,
    # matching the per-module lookups in the forward hooks and eviction path.
    if "." not in dotted_name:
        return root, dotted_name
    parent, leaf = dotted_name.rsplit(".", 1)
    return root.get_submodule(parent), leaf
def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
    if not disk_weights_enabled():
        return
    if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
        return
    for name, param in module.named_parameters(recurse=True):
        key = f"{prefix}{name}" if prefix else name
        if key in state_dict:
            meta = state_dict.meta(key)
            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
            owner, leaf = _owning_module(module, name)
            REGISTRY.register(owner, leaf, ref)
            if param.device.type == "cpu":
                CACHE.record(owner, leaf, param, is_buffer=False)
    for name, buf in module.named_buffers(recurse=True):
        key = f"{prefix}{name}" if prefix else name
        if key in state_dict and buf is not None:
            meta = state_dict.meta(key)
            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
            owner, leaf = _owning_module(module, name)
            REGISTRY.register(owner, leaf, ref)
            if buf.device.type == "cpu":
                CACHE.record(owner, leaf, buf, is_buffer=True)
def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
ref = REGISTRY.get(module)
if not ref or name not in ref:
return
disk_ref = ref[name]
shape = getattr(disk_ref.meta, "shape", None)
dtype = getattr(disk_ref.meta, "dtype", None)
if shape is None or dtype is None:
return
meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
if is_buffer:
module._buffers[name] = meta_tensor
else:
module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
def check(obj):
if torch.is_tensor(obj):
return obj.device
if isinstance(obj, (list, tuple)):
for item in obj:
dev = check(item)
if dev is not None:
return dev
if isinstance(obj, dict):
for item in obj.values():
dev = check(item)
if dev is not None:
return dev
return None
dev = check(args)
if dev is not None:
return dev
return check(kwargs)
def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
refs = REGISTRY.get(module)
if not refs:
return
for name, disk_ref in refs.items():
if name in module._parameters:
current = module._parameters[name]
is_buffer = False
elif name in module._buffers:
current = module._buffers[name]
is_buffer = True
else:
continue
if current is None:
continue
if current.device.type != "meta":
if current.device.type == "cpu":
CACHE.touch(module, name)
continue
tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
if is_buffer:
module._buffers[name] = tensor
else:
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu":
CACHE.record(module, name, tensor, is_buffer=is_buffer)
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
if not REGISTRY.has(module):
return
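    # Modules using comfy.ops cast-on-forward keep weights on CPU and cast per call;
    # everything else materializes directly on the device of the incoming tensors.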
if getattr(module, "comfy_cast_weights", False):
target_device = torch.device("cpu")
else:
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
ensure_module_materialized(module, target_device)
def attach_disk_weight_hooks(model: torch.nn.Module):
if not disk_weights_enabled():
return
for module in model.modules():
if getattr(module, "_disk_weight_hook_attached", False):
continue
        module.register_forward_pre_hook(disk_weight_pre_hook, with_kwargs=True)  # hook signature is (module, args, kwargs)
module._disk_weight_hook_attached = True
def evict_ram_cache(bytes_to_free: int):
if bytes_to_free <= 0:
return 0
return CACHE.evict_bytes(bytes_to_free)


@ -3,6 +3,7 @@ import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
import comfy.utils
ops = comfy.ops.manual_cast
@ -282,7 +283,7 @@ def load_gligen(sd):
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
gated.load_state_dict(n_sd, strict=False)
comfy.utils.load_state_dict(gated, n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
@ -293,7 +294,7 @@ def load_gligen(sd):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)
comfy.utils.load_state_dict(w, sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen


@ -1,4 +1,5 @@
import torch
import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return comfy.utils.load_state_dict(self.model, sd, strict=True)
def get_sd(self):
return self.model.state_dict()


@ -2,6 +2,7 @@ import json
from dataclasses import dataclass
import math
import torch
import comfy.utils
import torchaudio
import comfy.model_management
@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(


@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import comfy.utils
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@ -152,7 +153,7 @@ class VAE(nn.Module):
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
comfy.utils.load_state_dict(self, src_dict, strict=True)
@property
def device(self) -> torch.device:
@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')


@ -299,20 +299,14 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
if len(u) > 0:
logging.warning("unet unexpected: {}".format(u))
del to_load
return self
def process_latent_in(self, latent):
@ -751,8 +745,8 @@ class StableAudio1(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
def extra_conds(self, **kwargs):
out = {}


@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def sd_shape(state_dict, key):
return comfy.utils.state_dict_meta(state_dict, key).shape
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1]
use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
unet_config = {}
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1]
patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2]
unet_config["patch_size"] = patch_size
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
if final_layer in state_dict:
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size)
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64
unet_config["input_size"] = None
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
if y_key in state_dict_keys:
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1]
context_key = '{}context_embedder.weight'.format(key_prefix)
if context_key in state_dict_keys:
in_features = state_dict[context_key].shape[1]
out_features = state_dict[context_key].shape[0]
in_features = sd_shape(state_dict, context_key)[1]
out_features = sd_shape(state_dict, context_key)[0]
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
num_patches_key = '{}pos_embed'.format(key_prefix)
if num_patches_key in state_dict_keys:
num_patches = state_dict[num_patches_key].shape[1]
num_patches = sd_shape(state_dict, num_patches_key)[1]
unet_config["num_patches"] = num_patches
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
w = state_dict[text_mapper_name]
if w.shape[0] == 1536: #stage c lite
w_shape = sd_shape(state_dict, text_mapper_name)
if w_shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
elif w.shape[0] == 2048: #stage c full
elif w_shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
if w.shape[-1] == 640:
w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix))
if w_shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
elif w.shape[-1] == 576: #stage b lite
elif w_shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1]
unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix))
out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix))
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w_shape[2:])
dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["meanflow"] = False
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1]
dit_config["hidden_size"] = in_w_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["num_heads"] = in_w_shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["theta"] = 256
@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
w = state_dict[in_key]
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w.shape[0]
w_shape = sd_shape(state_dict, in_key)
dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w_shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
w = state_dict[txt_in_key]
dit_config["context_in_dim"] = w.shape[1]
dit_config["hidden_size"] = w.shape[0]
w_shape = sd_shape(state_dict, txt_in_key)
dit_config["context_in_dim"] = w_shape[1]
dit_config["hidden_size"] = w_shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1]
else:
dit_config["vec_in_dim"] = None
@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix))
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
dit_config["model_max_length"] = state_dict[y_key].shape[0]
dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix))
dit_config["dim"] = w_shape[0]
dit_config["cap_feat_dim"] = w_shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["qk_norm"] = True
@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None: # NewBie
dit_config["clip_text_dim"] = ctd_weight.shape[0]
ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix)
if ctd_key in state_dict_keys: # NewBie
dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1]
out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1]
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
flf_key = '{}img_emb.emb_pos'.format(key_prefix)
if flf_key in state_dict_keys:
dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1]
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
if ref_conv_weight is not None:
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
ref_conv_key = '{}ref_conv.weight'.format(key_prefix)
if ref_conv_key in state_dict_keys:
dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix))
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1]
dit_config["context_dim"] = 1024
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["image_model"] = "kandinsky5"
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0]
dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
if y_input in state_dict_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1]
else:
unet_config["adm_in_channels"] = None
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0]
in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
out_channels = sd_shape(state_dict, out_key)[0]
else:
out_channels = 4
@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1]
attn_res *= 2
if attn_blocks == 0:
@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0]
match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1]
elif "add_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
hidden_size = sd_shape(state_dict, "x_embedder.bias")[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
else:
return None


@ -27,6 +27,7 @@ import platform
import weakref
import gc
import os
import comfy.disk_weights
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -583,6 +584,8 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
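    # Under CPU memory pressure, let the disk tier drop cached RAM weights first.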
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
comfy.disk_weights.evict_ram_cache(memory_required)
unloaded_model = []
can_unload = []
unloaded_models = []
@ -1124,6 +1127,16 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
WEIGHTS_RAM_CACHE_BYTES = 0
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
if args.weights_ram_cache_gb is not None:
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
comfy.disk_weights.configure(
WEIGHTS_RAM_CACHE_BYTES,
allow_gds=WEIGHTS_GDS_ENABLED,
pin_if_cpu=not args.disable_pinned_memory,
)
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():


@ -34,6 +34,7 @@ import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
import comfy.disk_weights
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -269,6 +270,8 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
comfy.disk_weights.attach_disk_weight_hooks(self.model)
def model_size(self):
if self.size > 0:
return self.size
@ -1356,4 +1359,3 @@ class ModelPatcher:
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

comfy/safetensors_stream.py Normal file

@ -0,0 +1,799 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import collections
import importlib
import importlib.util
import os
import threading
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple
import torch
_FST_MODULE = None
_FST_LOCK = threading.Lock()
_FST_LOADED = False
_GDS_INITIALIZED = False
_MISSING = object()
def _require_fastsafetensors():
global _FST_MODULE
with _FST_LOCK:
if _FST_MODULE is None:
if importlib.util.find_spec("fastsafetensors") is None:
raise ImportError(
"fastsafetensors is required for safetensors streaming. "
"Install it with: pip install 'fastsafetensors @ https://github.com/"
"foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
)
_FST_MODULE = importlib.import_module("fastsafetensors")
return _FST_MODULE
def _init_fastsafetensors_lib():
global _FST_LOADED
fst = _require_fastsafetensors()
if not _FST_LOADED:
fst.cpp.load_library_functions()
_FST_LOADED = True
return fst
def _init_gds():
global _GDS_INITIALIZED
fst = _init_fastsafetensors_lib()
if not _GDS_INITIALIZED:
if fst.cpp.init_gds() != 0:
raise RuntimeError("fastsafetensors init_gds() failed")
_GDS_INITIALIZED = True
@dataclass(frozen=True)
class TensorMeta:
dtype: torch.dtype
shape: Tuple[int, ...]
numel: int
nbytes: int
data_offsets: Tuple[int, int]
filename: str
fst_dtype: object
strides: Tuple[int, ...]
class SafeTensorIndex:
def __init__(self, filename: str):
fst = _init_fastsafetensors_lib()
framework = fst.frameworks.get_framework_op("pytorch")
metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
self._filename = filename
self._metadata = metadata
self._framework = framework
from fastsafetensors.frameworks import _torch as fst_torch
self._dtype_map = fst_torch.dtype_convert
self._tensor_meta: Dict[str, TensorMeta] = {}
for key, frame in metadata.tensors.items():
torch_dtype = self._dtype_map.get(frame.dtype, None)
if torch_dtype is None:
raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
numel = 1
for s in frame.shape:
numel *= s
nbytes = numel * framework.get_dtype_size(frame.dtype)
self._tensor_meta[key] = TensorMeta(
dtype=torch_dtype,
shape=tuple(frame.shape),
numel=numel,
nbytes=nbytes,
data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
filename=filename,
fst_dtype=frame.dtype,
strides=tuple(frame.strides),
)
def keys(self) -> Iterable[str]:
return self._tensor_meta.keys()
def has(self, key: str) -> bool:
return key in self._tensor_meta
def meta(self, key: str) -> TensorMeta:
return self._tensor_meta[key]
def metadata(self):
return self._metadata.metadata
@property
def header_length(self) -> int:
return self._metadata.header_length
@property
def size_bytes(self) -> int:
return self._metadata.size_bytes
class _SafeTensorFile:
def __init__(self, filename: str, index: SafeTensorIndex):
self.filename = filename
self.index = index
self._fd: Optional[int] = None
self._gds_handle = None
self._gds_reader = None
self._nogds_reader = None
self._refcount = 1
def acquire(self) -> "_SafeTensorFile":
self._refcount += 1
return self
def release(self):
self._refcount -= 1
if self._refcount <= 0:
self.close()
def close(self):
if self._fd is not None:
os.close(self._fd)
self._fd = None
self._gds_handle = None
def _ensure_fd(self) -> int:
if self._fd is None:
self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
return self._fd
def _ensure_nogds_reader(self, use_cuda: bool):
fst = _init_fastsafetensors_lib()
if self._nogds_reader is None:
self._nogds_reader = fst.cpp.nogds_file_reader(
False, 16 * 1024, 16, use_cuda
)
return self._nogds_reader
def _ensure_gds_reader(self, use_cuda: bool):
fst = _init_fastsafetensors_lib()
if self._gds_reader is None:
self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
return self._gds_reader
def _ensure_gds_handle(self, use_cuda: bool):
if self._gds_handle is None:
fst = _init_fastsafetensors_lib()
framework = fst.frameworks.get_framework_op("pytorch")
o_direct = _get_gds_o_direct(framework)
self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
return self._gds_handle
def read_tensor(
self,
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
allow_gds: bool,
pin_if_cpu: bool,
) -> torch.Tensor:
fst = _init_fastsafetensors_lib()
framework = fst.frameworks.get_framework_op("pytorch")
device_is_cuda = device.type == "cuda"
if device_is_cuda and allow_gds:
_ensure_gds_ready(device)
tensor = self._read_tensor_gds(
fst, framework, meta, device, dtype
)
return tensor
cpu_tensor = self._read_tensor_nogds(
fst, framework, meta, torch.device("cpu"), dtype
)
if device_is_cuda:
if pin_if_cpu:
cpu_tensor = cpu_tensor.pin_memory()
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
return gpu_tensor
return cpu_tensor
def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
fst = _init_fastsafetensors_lib()
align = fst.cpp.get_alignment_size()
aligned_offset = (abs_start // align) * align
head = abs_start - aligned_offset
aligned_length = length + head
tail = aligned_length % align
if tail:
aligned_length += align - tail
return aligned_offset, aligned_length, head
def _read_tensor_nogds(
self,
fst,
framework,
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
fd = self._ensure_fd()
reader = self._ensure_nogds_reader(use_cuda=False)
abs_start = self.index.header_length + meta.data_offsets[0]
length = meta.data_offsets[1] - meta.data_offsets[0]
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
ptr_align = framework.get_device_ptr_align()
buffer_length = aligned_length + ptr_align
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
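        # Choose ptr_off so the tensor start (base + ptr_off + head) lands on the
        # framework's required device pointer alignment.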
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
if reader.wait_read(req) < 0:
fst.cpp.cpu_free(buf_ptr)
raise RuntimeError("nogds_file_reader read failed")
owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr))
tensor = _dlpack_tensor_from_buffer(
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
)
if dtype is not None and dtype != tensor.dtype:
_validate_dtype_conversion(tensor.dtype, dtype)
tensor = tensor.to(dtype=dtype)
return tensor
def _read_tensor_gds(
self,
fst,
framework,
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
reader = self._ensure_gds_reader(use_cuda=True)
handle = self._ensure_gds_handle(use_cuda=True)
abs_start = self.index.header_length + meta.data_offsets[0]
length = meta.data_offsets[1] - meta.data_offsets[0]
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
ptr_align = framework.get_device_ptr_align()
buffer_length = aligned_length + ptr_align
fst_device = _fst_device_from_torch(fst, device)
gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
file_length = self.index.size_bytes
req = reader.submit_read(
handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
)
if reader.wait_read(req) < 0:
framework.free_tensor_memory(gbuf, fst_device)
raise RuntimeError("gds_file_reader read failed")
owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
tensor = _dlpack_tensor_from_buffer(
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
)
if dtype is not None and dtype != tensor.dtype:
_validate_dtype_conversion(tensor.dtype, dtype)
tensor = tensor.to(dtype=dtype)
return tensor
def _fst_device_from_torch(fst, device: torch.device):
if device.type == "cuda" and device.index is not None:
return fst.st_types.Device.from_str(f"cuda:{device.index}")
return fst.st_types.Device.from_str(device.type)
class _BufferOwner:
def __init__(self, free_fn):
self._free_fn = free_fn
def __del__(self):
try:
self._free_fn()
except Exception:
pass
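# The DLPack wrapper below creates a tensor that aliases the raw buffer
# without copying, so a _BufferOwner is attached to the tensor
# (_comfy_disk_buffer_owner) to keep the allocation alive until the tensor
# itself is garbage collected.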
def _dlpack_tensor_from_buffer(
fst,
framework,
ptr: int,
meta: TensorMeta,
device: torch.device,
owner: Optional[_BufferOwner],
) -> torch.Tensor:
disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
dev = _fst_device_from_torch(fst, device)
dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
if disk_dtype != meta.fst_dtype:
torch_tensor = torch_tensor.view(meta.dtype)
if owner is not None:
torch_tensor._comfy_disk_buffer_owner = owner
return torch_tensor
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
# In-flight casts may only keep or shrink the element size; widening would
# require a second, larger staging buffer.
if torch.empty((), dtype=dst).element_size() > torch.empty((), dtype=src).element_size():
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
def _get_gds_o_direct(framework) -> bool:
# cuFile requires files opened with O_DIRECT on CUDA toolkits older than
# 12.2; newer toolkits can also read through the page cache. Version
# strings look like "cuda-12.2".
cuda_ver = framework.get_cuda_ver()
if cuda_ver and cuda_ver != "0.0":
ver_parts = cuda_ver.split("-", 1)
if len(ver_parts) == 2 and ver_parts[0] == "cuda":
cudavers = list(map(int, ver_parts[1].split(".")))
major = cudavers[0]
minor = cudavers[1] if len(cudavers) > 1 else 0
return not (major > 12 or (major == 12 and minor >= 2))
return True
return True
def _ensure_gds_ready(device: torch.device):
fst = _init_fastsafetensors_lib()
if not fst.common.is_gpu_found():
raise RuntimeError(
"GPUDirect requested but GPU runtime library is missing. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
if gds_supported < 0:
raise RuntimeError(
"GPUDirect requested but is_gds_supported() failed. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
if not fst.cpp.is_cufile_found():
raise RuntimeError(
"GPUDirect requested but libcufile is missing. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
if gds_supported == 0:
raise RuntimeError(
"GPUDirect requested but GDS is unsupported on this platform. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
_init_gds()
class StreamStateDict(collections.abc.MutableMapping):
is_stream_state_dict = True
def __init__(
self,
index: SafeTensorIndex,
file: _SafeTensorFile,
device: torch.device,
allow_gds: bool = False,
):
self._index = index
self._file = file
self._device = device
self._allow_gds = allow_gds
self._overrides: Dict[str, torch.Tensor] = {}
self._deleted: set[str] = set()
@classmethod
def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
index = SafeTensorIndex(filename)
file = _SafeTensorFile(filename, index)
return cls(index, file, device, allow_gds=allow_gds)
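# Minimal usage sketch (hypothetical key and file path; GDS off by default):
#   sd = StreamStateDict.from_file("model.safetensors", torch.device("cpu"))
#   sd.meta("some.key").shape   # header metadata only, no tensor I/O
#   sd["some.key"]              # streams just this one tensor from disk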
def close(self):
if self._file is not None:
self._file.release()
self._file = None
def __del__(self):
try:
self.close()
except Exception:
pass
def meta(self, key: str) -> TensorMeta:
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return TensorMeta(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
data_offsets=(0, numel * t.element_size()),
filename="<override>",
fst_dtype=None,
strides=tuple(t.stride()),
)
if key in self._deleted:
raise KeyError(key)
return self._index.meta(key)
def get_tensor(
self,
key: str,
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
) -> torch.Tensor:
if key in self._overrides:
t = self._overrides[key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
_validate_dtype_conversion(t.dtype, dtype)
t = t.to(dtype=dtype)
return t
if key in self._deleted:
raise KeyError(key)
if device is None:
device = self._device
if allow_gds is None:
allow_gds = self._allow_gds
meta = self._index.meta(key)
return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)
def __getitem__(self, key: str) -> torch.Tensor:
return self.get_tensor(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._deleted:
raise KeyError(key)
if self._index.has(key):
self._deleted.add(key)
return
raise KeyError(key)
def __iter__(self) -> Iterator[str]:
for k in self._index.keys():
if k in self._deleted:
continue
if k in self._overrides:
continue
yield k
for k in self._overrides.keys():
yield k
def __len__(self) -> int:
# Count via iteration so overrides that shadow on-disk keys are not
# double counted.
return sum(1 for _ in self)
def __contains__(self, key: object) -> bool:
if not isinstance(key, str):
return False
if key in self._deleted:
return False
if key in self._overrides:
return True
return self._index.has(key)
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
if key in self._overrides:
return self._overrides.pop(key)
if key in self._deleted:
if default is _MISSING:
raise KeyError(key)
return default
if self._index.has(key):
# Read before tombstoning: get_tensor() raises KeyError for deleted keys.
tensor = self.get_tensor(key)
self._deleted.add(key)
return tensor
if default is _MISSING:
raise KeyError(key)
return default
def copy(self) -> "StreamStateDict":
new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
new._overrides = dict(self._overrides)
new._deleted = set(self._deleted)
return new
def metadata(self):
return self._index.metadata()
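# The view classes below rewrite key namespaces without materializing
# tensors: FilterViewStateDict drops keys, PrefixViewStateDict and
# RenameViewStateDict rewrite prefixes, MergedStateDict overlays several
# mappings, and MappedStateDict applies an explicit key map. Reads still
# resolve through the base mapping's get_tensor(), so a chain of views over
# a StreamStateDict stays disk backed end to end.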
class _BaseViewStateDict(MutableMapping):
is_stream_state_dict = True
def __init__(self, base: MutableMapping, mutate_base: bool = False):
self._base = base
self._mutate_base = mutate_base
self._overrides: Dict[str, torch.Tensor] = {}
self._deleted: set[str] = set()
def _resolve_base_key(self, key: str) -> Optional[str]:
return key
def _iter_base_keys(self) -> Iterable[str]:
return self._base.keys()
def get_tensor(
self,
key: str,
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
) -> torch.Tensor:
if key in self._overrides:
t = self._overrides[key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
_validate_dtype_conversion(t.dtype, dtype)
t = t.to(dtype=dtype)
return t
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if hasattr(self._base, "get_tensor"):
return self._base.get_tensor(
base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
)
t = self._base[base_key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
_validate_dtype_conversion(t.dtype, dtype)
t = t.to(dtype=dtype)
return t
def meta(self, key: str):
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if hasattr(self._base, "meta"):
return self._base.meta(base_key)
t = self._base[base_key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
def __getitem__(self, key: str) -> torch.Tensor:
return self.get_tensor(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
base_key = self._resolve_base_key(key)
if self._mutate_base and base_key is not None and base_key in self._base:
self._base[base_key] = value
else:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if self._mutate_base and base_key in self._base:
del self._base[base_key]
else:
self._deleted.add(key)
def __iter__(self) -> Iterator[str]:
for k in self._iter_base_keys():
if k in self._deleted:
continue
if k in self._overrides:
continue
yield k
for k in self._overrides.keys():
yield k
def __len__(self) -> int:
# Count via iteration so overrides that shadow base keys are not double
# counted (they are yielded once, from _overrides).
return sum(1 for _ in self)
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
if key in self._overrides:
return self._overrides.pop(key)
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
if default is _MISSING:
raise KeyError(key)
return default
if self._mutate_base:
try:
return self._base.pop(base_key)
except KeyError:
if default is _MISSING:
raise
return default
# Read before tombstoning: get_tensor() raises KeyError for deleted keys.
tensor = self.get_tensor(key)
self._deleted.add(key)
return tensor
class FilterViewStateDict(_BaseViewStateDict):
def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
super().__init__(base, mutate_base=mutate_base)
self._predicate = predicate
def _resolve_base_key(self, key: str) -> Optional[str]:
if self._predicate(key):
return key
return None
def _iter_base_keys(self) -> Iterable[str]:
for k in self._base.keys():
if self._predicate(k):
yield k
class PrefixViewStateDict(_BaseViewStateDict):
def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
super().__init__(base, mutate_base=mutate_base)
self._source_prefix = source_prefix
self._target_prefix = target_prefix
self._mapping: Dict[str, str] = {}
self._reverse: Dict[str, str] = {}
for k in base.keys():
if not k.startswith(source_prefix):
continue
view_key = f"{target_prefix}{k[len(source_prefix):]}"
self._mapping[k] = view_key
self._reverse[view_key] = k
def _resolve_base_key(self, key: str) -> Optional[str]:
return self._reverse.get(key)
def _iter_base_keys(self) -> Iterable[str]:
return self._reverse.keys()
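# Example (hypothetical keys): PrefixViewStateDict(sd, "encoder.") exposes
# "encoder.conv_in.weight" as "conv_in.weight" while reading bytes from the
# original entry. The key mapping is snapshotted from base.keys() at
# construction; keys added to the base later are not visible in the view.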
class RenameViewStateDict(_BaseViewStateDict):
def __init__(
self,
base: MutableMapping,
replace_prefix: Mapping[str, str],
filter_keys: bool = False,
mutate_base: bool = False,
):
super().__init__(base, mutate_base=mutate_base)
self._filter_keys = filter_keys
self._replace = list(replace_prefix.items())
self._mapping: Dict[str, str] = {}
self._reverse: Dict[str, str] = {}
for k in base.keys():
view_key = self._replace_key(k)
if view_key is None:
continue
self._mapping[k] = view_key
self._reverse[view_key] = k
def _replace_key(self, key: str) -> Optional[str]:
for rp, dst in self._replace:
if key.startswith(rp):
return f"{dst}{key[len(rp):]}"
if self._filter_keys:
return None
return key
def _resolve_base_key(self, key: str) -> Optional[str]:
return self._reverse.get(key)
def _iter_base_keys(self) -> Iterable[str]:
return self._reverse.keys()
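# Example (hypothetical prefixes): RenameViewStateDict(sd,
# {"first_stage_model.": ""}, filter_keys=True) mirrors what
# state_dict_prefix_replace() does for plain dicts: "first_stage_model.x"
# appears as "x" and non-matching keys are filtered out. Like
# PrefixViewStateDict, the mapping is snapshotted at construction.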
class MergedStateDict(MutableMapping):
is_stream_state_dict = True
def __init__(self, *mappings: MutableMapping):
self._mappings = list(mappings)
self._overrides: Dict[str, torch.Tensor] = {}
self._deleted: set[str] = set()
def __getitem__(self, key: str) -> torch.Tensor:
if key in self._overrides:
return self._overrides[key]
if key in self._deleted:
raise KeyError(key)
for mapping in reversed(self._mappings):
if key in mapping:
if hasattr(mapping, "get_tensor"):
return mapping.get_tensor(key)
return mapping[key]
raise KeyError(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._deleted:
raise KeyError(key)
if any(key in mapping for mapping in self._mappings):
self._deleted.add(key)
return
raise KeyError(key)
def __iter__(self) -> Iterator[str]:
seen = set()
for mapping in self._mappings:
for key in mapping.keys():
if key in self._deleted or key in seen:
continue
seen.add(key)
yield key
for key in self._overrides.keys():
if key not in seen:
yield key
def __len__(self) -> int:
return len(list(self.__iter__()))
def meta(self, key: str):
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
if key in self._deleted:
raise KeyError(key)
for mapping in reversed(self._mappings):
if key in mapping:
if hasattr(mapping, "meta"):
return mapping.meta(key)
t = mapping[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
raise KeyError(key)
class MappedStateDict(_BaseViewStateDict):
def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
super().__init__(base, mutate_base=mutate_base)
self._base_to_view = dict(key_map)
self._view_to_base = {v: k for k, v in key_map.items()}
def _resolve_base_key(self, key: str) -> Optional[str]:
return self._view_to_base.get(key)
def _iter_base_keys(self) -> Iterable[str]:
return self._view_to_base.keys()
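# End-to-end sketch of the intended composition (hypothetical key names):
#   sd = StreamStateDict.from_file("ckpt.safetensors", torch.device("cpu"))
#   unet = RenameViewStateDict(sd, {"model.diffusion_model.": ""}, filter_keys=True)
#   shapes = {k: unet.meta(k).shape for k in unet}   # header-only pass
#   w = unet["input_blocks.0.0.weight"]              # single tensor read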

View File

@ -25,6 +25,7 @@ import math
import os
import comfy.utils
import comfy.safetensors_stream
from . import clip_vision
from . import gligen
@ -288,7 +289,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@ -346,7 +347,7 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
@ -361,25 +362,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["encoder.{}".format(k)] = sd[k]
sd = new_sd
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["previewer.{}".format(k)] = sd[k]
sd = new_sd
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."})
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -389,9 +384,9 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@ -414,7 +409,7 @@ class VAE:
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
@ -427,7 +422,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@ -462,11 +457,11 @@ class VAE:
self.downscale_index_formula = (6, 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight")
version = 0
if tensor_conv1.shape[0] == 512:
if tensor_conv1_shape[0] == 512:
version = 0
elif tensor_conv1.shape[0] == 1024:
elif tensor_conv1_shape[0] == 1024:
version = 1
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
@ -483,9 +478,9 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
self.latent_channels = 32
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@ -509,8 +504,8 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
#This is likely to significantly over-estimate with single image or low frame counts as the
#implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
@ -543,14 +538,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
dim = sd["decoder.head.0.gamma"].shape[0]
dim = sd_shape(sd, "decoder.head.0.gamma")[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
@ -626,7 +621,7 @@ class VAE:
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
self.latent_channels = sd["decoder.1.weight"].shape[1]
self.latent_channels = sd_shape(sd, "decoder.1.weight")[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@ -637,12 +632,12 @@ class VAE:
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
@ -662,7 +657,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@ -983,9 +978,12 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
comfy.utils.load_state_dict(model, model_data, strict=True)
return StyleModel(model)
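# sd_shape() probes tensor shapes via header metadata, so the model-detection
# branches below never materialize the weights they inspect when sd is a
# streaming state dict.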
def sd_shape(state_dict, key):
return comfy.utils.state_dict_meta(state_dict, key).shape
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
@ -1055,16 +1053,16 @@ def detect_te_model(sd):
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight")
if weight_shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
elif weight_shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
if weight.shape[0] == 384:
weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight')
if weight_shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
@ -1074,19 +1072,19 @@ def detect_te_model(sd):
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias')
if weight_shape[0] == 256:
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
if weight_shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight')
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if weight.shape[0] == 2560:
if weight_shape[0] == 2560:
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
elif weight_shape[0] == 2048:
return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if weight_shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
@ -1415,19 +1413,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
out_sd = {}
for k in sd:
skip = False
if comfy.utils.is_stream_state_dict(sd):
def _keep_key(k, prefixes=tuple(scaled_fp8_list)):
return not any(k.startswith(pref) for pref in prefixes)
out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
merged = out_sd
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd)
sd = merged
else:
out_sd = {}
for k in sd:
skip = False
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
@ -1505,12 +1513,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
if comfy.utils.is_stream_state_dict(sd):
new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys)
else:
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)

View File

@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
def parse_parentheses(string):
result = []
@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
try:
if embed_path.lower().endswith(".safetensors"):
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
embed = comfy.utils.load_torch_file(embed_path, safe_load=True)
else:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")

View File

@ -56,9 +56,9 @@ class TAESD(nn.Module):
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
if decoder_path is not None:
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
@staticmethod
def scale_latents(x):

View File

@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
if len(sdo) == 0:
sdo = sd
return self.load_state_dict(sdo, strict=False)
return comfy.utils.load_state_dict(self, sdo, strict=False)
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@ -26,10 +26,13 @@ import numpy as np
from PIL import Image
import logging
import itertools
from types import SimpleNamespace
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
from . import safetensors_stream
import comfy.disk_weights
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
if return_metadata:
metadata = f.metadata()
sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device)
if return_metadata:
metadata = sd.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
params += w.nelement()
meta = state_dict_meta(sd, k)
params += meta.numel
return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
meta = state_dict_meta(sd, k)
dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel
if len(dtypes) == 0:
return None
@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace):
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if is_stream_state_dict(state_dict):
return safetensors_stream.RenameViewStateDict(
state_dict,
replace_prefix,
filter_keys=filter_keys,
mutate_base=not filter_keys,
)
if filter_keys:
out = {}
else:
@ -145,6 +149,77 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
return out
def is_stream_state_dict(state_dict) -> bool:
return getattr(state_dict, "is_stream_state_dict", False)
def state_dict_meta(state_dict, key):
if hasattr(state_dict, "meta"):
return state_dict.meta(key)
w = state_dict[key]
numel = w.numel()
return SimpleNamespace(
dtype=w.dtype,
shape=tuple(w.shape),
numel=numel,
nbytes=numel * w.element_size(),
)
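# Example: state_dict_meta(sd, "decoder.conv_in.weight").shape behaves the
# same for a plain dict (tensor already in memory) and for a StreamStateDict
# (answered from the parsed safetensors header, no disk read).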
def load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict):
comfy.disk_weights.register_module_weights(model, state_dict)
comfy.disk_weights.attach_disk_weight_hooks(model)
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
return missing, unexpected
return model.load_state_dict(state_dict, strict=strict, assign=assign)
def stream_load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = safetensors_stream.FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return missing_keys, unexpected_keys
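# Unlike Module.load_state_dict(), this walks the module tree and gives each
# submodule a filtered view of the keys, so tensors are resolved (and hence
# streamed from disk) one submodule at a time instead of all at once.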
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
@ -1217,46 +1292,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
if is_stream_state_dict(state_dict):
scaled_meta = state_dict_meta(state_dict, scaled_fp8_key)
scaled_fp8_dtype = scaled_meta.dtype
scaled_fp8_weight_nelements = scaled_meta.numel
else:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
scaled_fp8_weight_nelements = scaled_fp8_weight.nelement()
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
if scaled_fp8_weight_nelements == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
out_sd = {}
layers = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
if is_stream_state_dict(state_dict):
key_map = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
key_map[k] = k
continue
k_out = k
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
out_sd[k_out] = w
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
state_dict = out_sd
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
scale_val = state_dict[k]
if scale_val.item() == 1.0:
continue
key_map[k] = k_out
state_dict = safetensors_stream.MappedStateDict(state_dict, key_map)
else:
out_sd = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
continue
out_sd[k_out] = w
state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])

View File

@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@ -525,7 +524,7 @@ class LoadLatent:
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
latent = comfy.utils.load_torch_file(latent_path, safe_load=True)
multiplier = 1.0
if "latent_format_version_0" not in latent:
multiplier = 1.0 / 0.18215

View File

@ -11,6 +11,7 @@ transformers>=4.50.3
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip
aiohttp>=3.11.8
yarl>=1.18.0
pyyaml

View File

@ -0,0 +1,146 @@
import os
import pytest
import importlib
import importlib.util
torch = pytest.importorskip("torch")
def _write_safetensors(tmp_path, tensors):
import safetensors.torch
path = os.path.join(tmp_path, "test.safetensors")
safetensors.torch.save_file(tensors, path)
return path
def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
meta = sd.meta("a")
assert meta.shape == (2, 3)
assert meta.dtype == torch.float32
assert meta.numel == 6
assert calls == []
def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
_ = sd["a"]
assert len(calls) == 1
assert calls[0].shape == (2, 3)
def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
sd = comfy.utils.load_torch_file(path, safe_load=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
_ = list(view.keys())
assert calls == []
def test_stream_load_rss_small(tmp_path):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
psutil = pytest.importorskip("psutil")
process = psutil.Process()
size_elems = 4_000_000 # ~16MB float32
tensor = torch.zeros((size_elems,), dtype=torch.float32)
path = _write_safetensors(tmp_path, {"big": tensor})
rss_before = process.memory_info().rss
sd = comfy.utils.load_torch_file(path, safe_load=True)
rss_after = process.memory_info().rss
expected_size = tensor.numel() * tensor.element_size()
assert (rss_after - rss_before) < expected_size
_ = sd.meta("big")
def test_gds_path_errors_without_support(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
device = torch.device("cuda")
if importlib.util.find_spec("fastsafetensors") is None:
fst = None
else:
fst = importlib.import_module("fastsafetensors")
gds_available = False
if fst is not None and torch.cuda.is_available():
gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported > 0
if not gds_available:
with pytest.raises(RuntimeError, match="GPUDirect requested"):
sd.get_tensor("a", device=device, allow_gds=True)
else:
def fail_nogds(*args, **kwargs):
raise AssertionError("nogds path used during GDS request")
monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
t = sd.get_tensor("a", device=device, allow_gds=True)
assert t.device.type == "cuda"
def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
import comfy.disk_weights
prev_cache = comfy.disk_weights.CACHE.max_bytes
prev_gds = comfy.disk_weights.ALLOW_GDS
prev_pin = comfy.disk_weights.PIN_IF_CPU
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
try:
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
model = torch.nn.Linear(4, 4, bias=True)
comfy.utils.load_state_dict(model, sd, strict=False)
assert model.weight.device.type == "cpu"
assert model.weight.device.type != "meta"
finally:
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)