ifilipis 2026-01-09 09:00:27 +02:00 committed by GitHub
commit 7d9b1b0885
25 changed files with 2794 additions and 246 deletions

DESIGN.md (new file, 57 lines)

@ -0,0 +1,57 @@
# Disk tier safetensors streaming design audit (ComfyUI)
## Mandatory research audit (verified call sites)
### ComfyUI load path + eager materialization sites
- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】
- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】
- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes if used on a streaming mapping).【F:comfy/utils.py†L135-L144】
- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】
- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
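For reference, a minimal paraphrase of the eager pattern the audit above flags (not the exact ComfyUI source): every tensor is materialized before the caller sees the dict, and metadata is only returned after all tensors have been read.

```python
import safetensors

def load_torch_file_eager(path):
    sd = {}
    with safetensors.safe_open(path, framework="pt", device="cpu") as f:
        for k in f.keys():
            sd[k] = f.get_tensor(k)  # allocates the full tensor in RAM
        metadata = f.metadata()
    return sd, metadata
```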
### FastSafeTensors (fastsafetensors) capability audit
- Header parsing and metadata:
- `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
- `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
- GDS + no-GDS low-level readers:
- `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
- GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
- DLPack wrapping:
- `fastsafetensors/dlpack.py` provides `from_cuda_buffer()` which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used for `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
- Torch framework interop:
- `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
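A small metadata-only probe sketch, following the call pattern used later in this commit's `comfy/safetensors_stream.py` (`get_framework_op`, `SafeTensorsMetadata.from_file`); the frame fields (`dtype`, `shape`, `data_offsets`) are the ones listed above, and whether `load_library_functions()` is strictly required for header parsing is an assumption.

```python
import fastsafetensors as fst

def probe_safetensors(path: str):
    # the new module calls this once before any other fastsafetensors call
    fst.cpp.load_library_functions()
    framework = fst.frameworks.get_framework_op("pytorch")
    meta = fst.common.SafeTensorsMetadata.from_file(path, framework)
    # header-only: per-tensor dtype/shape/offsets, no tensor data is allocated
    return {
        name: (frame.dtype, tuple(frame.shape), tuple(frame.data_offsets))
        for name, frame in meta.tensors.items()
    }
```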
### VRAM/RAM offload logic (for extension)
- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps track of loaded/offloaded memory (needs integration with the disk tier for RAM eviction).【F:comfy/model_management.py†L584-L612】
- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
## Strategy summary (implemented)
### Streaming safetensors mapping (no full dict materialization)
- [x] Introduce a new module `comfy/safetensors_stream.py` with:
- [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
- [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
- [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
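A minimal sketch of the mapping-view idea, assuming only that the base mapping exposes `keys()` and reads a tensor on `__getitem__`; the real `PrefixViewStateDict`/`FilterViewStateDict`/`RenameViewStateDict` add filtering and rename maps on top of this.

```python
from collections.abc import Mapping

class LazyPrefixView(Mapping):
    """Lazy prefix view: keys are rewritten up front (metadata only),
    tensors are read from the base mapping only on access."""

    def __init__(self, base: Mapping, prefix: str):
        self._base = base
        self._prefix = prefix
        self._keys = [k[len(prefix):] for k in base.keys() if k.startswith(prefix)]

    def __len__(self):
        return len(self._keys)

    def __iter__(self):
        return iter(self._keys)

    def __getitem__(self, key):
        # only this access touches tensor data
        return self._base[self._prefix + key]
```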
### Range reads and tiering
- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
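A simplified disk→RAM→GPU staging sketch using plain `os.pread` instead of the fastsafetensors readers (the real path goes through `nogds_file_reader`/`gds_file_reader` and DLPack); `header_len` and a `meta` object carrying `data_offsets`/`dtype`/`shape` are assumed to come from the metadata index.

```python
import os
import torch

def read_tensor_range(path, header_len, meta, device="cuda", pin=True):
    start, end = meta.data_offsets          # offsets are relative to the data section
    fd = os.open(path, os.O_RDONLY)
    try:
        # read only this tensor's bytes, not the whole file
        raw = os.pread(fd, end - start, header_len + start)
    finally:
        os.close(fd)
    cpu = torch.frombuffer(bytearray(raw), dtype=meta.dtype).reshape(meta.shape)
    if pin:
        cpu = cpu.pin_memory()              # pinned staging buffer for async H2D copy
    return cpu if device == "cpu" else cpu.to(device, non_blocking=True)
```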
### Disk tier integration
- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
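A hedged sketch of the materialize-on-forward hook; `DISK_REFS` and its loader callables are hypothetical stand-ins for the `DiskRef` registry described above (the real logic lives in `comfy/disk_weights.py`).

```python
import torch

DISK_REFS = {}  # hypothetical: (module, param_name) -> callable returning a torch.Tensor

def materialize_pre_hook(module, args):
    # before the module runs, replace meta-device parameters that have a disk
    # reference with freshly loaded tensors
    for name, param in list(module.named_parameters(recurse=False)):
        if param.device.type == "meta":
            loader = DISK_REFS.get((module, name))
            if loader is not None:
                tensor = loader()  # range-read from the safetensors file
                setattr(module, name, torch.nn.Parameter(tensor, requires_grad=False))
    return None

# attached once per module, e.g. module.register_forward_pre_hook(materialize_pre_hook)
```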
### Pipeline refactors
- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
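A sketch of a metadata-aware parameter count: when the mapping can report shapes without loading (here via a hypothetical `meta()` accessor, mirroring the new `comfy.utils.state_dict_meta` helper), no tensor data is read.

```python
import math

def calculate_parameters_lazy(sd, prefix=""):
    total = 0
    for k in sd.keys():
        if not k.startswith(prefix):
            continue
        if hasattr(sd, "meta"):            # streaming mapping: shape from the header only
            shape = sd.meta(k).shape
        else:                              # plain dict fallback: tensor already in memory
            shape = sd[k].shape
        total += math.prod(shape)
    return total
```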
### Tests and docs
- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
- [x] Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
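One possible shape for the no-materialization unit test; the `StreamStateDict` entry point and its constructor signature are assumptions here, not the verified API.

```python
import safetensors.torch
import torch

def test_stream_state_dict_lazy(tmp_path):
    path = tmp_path / "tiny.safetensors"
    safetensors.torch.save_file({"a.weight": torch.zeros(4, 2)}, str(path))

    from comfy.safetensors_stream import StreamStateDict  # assumed entry point
    sd = StreamStateDict(str(path))

    assert set(sd.keys()) == {"a.weight"}   # metadata-only iteration
    t = sd["a.weight"]                      # single on-demand tensor read
    assert tuple(t.shape) == (4, 2) and t.dtype == torch.float32
```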

View File

@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
### Disk tier for model weights
When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. When the cache limit is exceeded, weights are dropped from RAM and re-read from disk on demand.
If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
# Running

View File

@ -29,7 +29,7 @@ class AudioEncoderModel():
self.model_sample_rate = 16000
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")

View File

@ -48,7 +48,7 @@ class ClipVisionModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@ -25,6 +25,7 @@ import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.disk_weights
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
@ -385,7 +386,7 @@ class ControlLora(ControlNet):
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.control_model.to(comfy.model_management.get_torch_device())
comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
@ -439,7 +440,7 @@ def controlnet_config(sd, model_options={}):
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -473,9 +474,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
@ -748,9 +749,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -816,8 +817,8 @@ class T2IAdapter(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
self.t2i_model.to(x_noisy.dtype)
self.t2i_model.to(self.device)
comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
comfy.disk_weights.module_to(self.t2i_model, self.device)
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
@ -874,7 +875,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
else:
return None
missing, unexpected = model_ad.load_state_dict(t2i_data)
missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))

comfy/disk_weights.py (new file, 1115 lines)

File diff suppressed because it is too large

View File

@ -3,6 +3,7 @@ import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
import comfy.utils
ops = comfy.ops.manual_cast
@ -282,7 +283,7 @@ def load_gligen(sd):
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
gated.load_state_dict(n_sd, strict=False)
comfy.utils.load_state_dict(gated, n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
@ -293,7 +294,7 @@ def load_gligen(sd):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)
comfy.utils.load_state_dict(w, sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen

View File

@ -1,4 +1,5 @@
import torch
import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return comfy.utils.load_state_dict(self.model, sd, strict=True)
def get_sd(self):
return self.model.state_dict()

View File

@ -2,6 +2,7 @@ import json
from dataclasses import dataclass
import math
import torch
import comfy.utils
import torchaudio
import comfy.model_management
@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(

View File

@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
import comfy.utils
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@ -152,7 +153,7 @@ class VAE(nn.Module):
return dec, posterior
def load_weights(self, src_dict) -> None:
self.load_state_dict(src_dict, strict=True)
comfy.utils.load_state_dict(self, src_dict, strict=True)
@property
def device(self) -> torch.device:
@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')

View File

@ -56,6 +56,7 @@ import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
from . import safetensors_stream
import comfy.latent_formats
import comfy.model_sampling
import math
@ -299,20 +300,21 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
replace_prefix = {unet_prefix: ""} if unet_prefix else {}
if replace_prefix:
if utils.is_stream_state_dict(sd):
to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
else:
to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
else:
to_load = sd
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
if len(u) > 0:
logging.warning("unet unexpected: {}".format(u))
del to_load
return self
def process_latent_in(self, latent):
@ -751,8 +753,8 @@ class StableAudio1(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
def extra_conds(self, **kwargs):
out = {}

View File

@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def sd_shape(state_dict, key):
return comfy.utils.state_dict_meta(state_dict, key).shape
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1]
use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
unet_config = {}
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1]
patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2]
unet_config["patch_size"] = patch_size
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
if final_layer in state_dict:
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size)
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64
unet_config["input_size"] = None
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
if y_key in state_dict_keys:
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1]
context_key = '{}context_embedder.weight'.format(key_prefix)
if context_key in state_dict_keys:
in_features = state_dict[context_key].shape[1]
out_features = state_dict[context_key].shape[0]
in_features = sd_shape(state_dict, context_key)[1]
out_features = sd_shape(state_dict, context_key)[0]
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
num_patches_key = '{}pos_embed'.format(key_prefix)
if num_patches_key in state_dict_keys:
num_patches = state_dict[num_patches_key].shape[1]
num_patches = sd_shape(state_dict, num_patches_key)[1]
unet_config["num_patches"] = num_patches
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
w = state_dict[text_mapper_name]
if w.shape[0] == 1536: #stage c lite
w_shape = sd_shape(state_dict, text_mapper_name)
if w_shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
elif w.shape[0] == 2048: #stage c full
elif w_shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
if w.shape[-1] == 640:
w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix))
if w_shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
elif w.shape[-1] == 576: #stage b lite
elif w_shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1]
unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix))
out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix))
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w_shape[2:])
dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["meanflow"] = False
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1]
dit_config["hidden_size"] = in_w_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["num_heads"] = in_w_shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["theta"] = 256
@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
w = state_dict[in_key]
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w.shape[0]
w_shape = sd_shape(state_dict, in_key)
dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w_shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
w = state_dict[txt_in_key]
dit_config["context_in_dim"] = w.shape[1]
dit_config["hidden_size"] = w.shape[0]
w_shape = sd_shape(state_dict, txt_in_key)
dit_config["context_in_dim"] = w_shape[1]
dit_config["hidden_size"] = w_shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1]
else:
dit_config["vec_in_dim"] = None
@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix))
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
dit_config["model_max_length"] = state_dict[y_key].shape[0]
dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix))
dit_config["dim"] = w_shape[0]
dit_config["cap_feat_dim"] = w_shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["qk_norm"] = True
@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None: # NewBie
dit_config["clip_text_dim"] = ctd_weight.shape[0]
ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix)
if ctd_key in state_dict_keys: # NewBie
dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1]
out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1]
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
flf_key = '{}img_emb.emb_pos'.format(key_prefix)
if flf_key in state_dict_keys:
dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1]
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
if ref_conv_weight is not None:
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
ref_conv_key = '{}ref_conv.weight'.format(key_prefix)
if ref_conv_key in state_dict_keys:
dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix))
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1]
dit_config["context_dim"] = 1024
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["image_model"] = "kandinsky5"
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0]
dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
if y_input in state_dict_keys:
unet_config["num_classes"] = "sequential"
unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1]
else:
unet_config["adm_in_channels"] = None
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0]
in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
out_channels = sd_shape(state_dict, out_key)[0]
else:
out_channels = 4
@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1]
attn_res *= 2
if attn_blocks == 0:
@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0]
match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1]
elif "add_embedding.linear_1.weight" in state_dict:
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
hidden_size = sd_shape(state_dict, "x_embedder.bias")[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
else:
return None

View File

@ -26,6 +26,7 @@ import platform
import weakref
import gc
import os
import comfy.disk_weights
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -540,7 +541,12 @@ class LoadedModel:
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
offload_device = None
if comfy.disk_weights.disk_weights_enabled():
offload_device = torch.device("meta")
self.model.detach(unpatch_weights, offload_device=offload_device)
if offload_device is not None and offload_device.type == "meta":
logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@ -594,6 +600,11 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
if freed_cache < memory_required:
evict_ram_to_disk(memory_required - freed_cache)
unloaded_model = []
can_unload = []
unloaded_models = []
@ -629,6 +640,34 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
if memory_to_free <= 0:
return 0
if not comfy.disk_weights.disk_weights_enabled():
return 0
freed = 0
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model not in keep_loaded and not shift_model.is_dead():
loaded_memory = shift_model.model_loaded_memory()
if loaded_memory > 0:
can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
for x in sorted(can_unload):
i = x[-1]
memory_needed = memory_to_free - freed
if memory_needed <= 0:
break
logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
if freed > 0:
logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
return freed
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@ -1135,6 +1174,16 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
WEIGHTS_RAM_CACHE_BYTES = 0
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
if args.weights_ram_cache_gb is not None:
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
comfy.disk_weights.configure(
WEIGHTS_RAM_CACHE_BYTES,
allow_gds=WEIGHTS_GDS_ENABLED,
pin_if_cpu=not args.disable_pinned_memory,
)
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
@ -1291,7 +1340,10 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if hasattr(dev, 'type') and dev.type == "meta":
mem_free_total = sys.maxsize
mem_free_torch = mem_free_total
elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:

View File

@ -34,6 +34,7 @@ import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
import comfy.disk_weights
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -269,6 +270,8 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
comfy.disk_weights.attach_disk_weight_hooks(self.model)
def model_size(self):
if self.size > 0:
return self.size
@ -783,7 +786,7 @@ class ModelPatcher:
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
comfy.disk_weights.module_to(x[2], device_to)
for x in offloaded:
n = x[1]
@ -799,7 +802,7 @@ class ModelPatcher:
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
comfy.disk_weights.module_to(self.model, device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter
@ -856,7 +859,7 @@ class ModelPatcher:
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@ -883,6 +886,9 @@ class ModelPatcher:
if len(unload_list) > 0:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
remaining_ram = None
if device_to is not None and comfy.model_management.is_device_cpu(device_to):
remaining_ram = comfy.model_management.get_free_memory(device_to)
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
@ -916,7 +922,24 @@ class ModelPatcher:
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
freed_bytes = module_mem
if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
freed_bytes = comfy.disk_weights.offload_module_weights(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
freed_bytes = comfy.disk_weights.offload_module_weights(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
if comfy.disk_weights.disk_weights_enabled():
comfy.disk_weights.move_module_tensors(m, device_to)
else:
m.to(device_to)
if remaining_ram is not None:
remaining_ram = max(0, remaining_ram - module_mem)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@ -939,7 +962,7 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
memory_freed += freed_bytes
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
@ -953,7 +976,8 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@ -984,11 +1008,12 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
def detach(self, unpatch_all=True):
def detach(self, unpatch_all=True, offload_device=None):
self.eject_model()
self.model_patches_to(self.offload_device)
target_device = self.offload_device if offload_device is None else offload_device
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
self.unpatch_model(target_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model
@ -1358,4 +1383,3 @@ class ModelPatcher:
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

View File

@ -19,6 +19,7 @@
import torch
import logging
import comfy.model_management
import comfy.disk_weights
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
@ -98,11 +99,35 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
weight_source = s.weight
bias_source = s.bias
if comfy.disk_weights.disk_weights_enabled():
if weight_source.device.type == "meta":
loaded = comfy.disk_weights.load_module_tensor(
s,
"weight",
device,
temporary=True,
dtype_override=dtype,
)
if loaded is not None:
weight_source = loaded
if bias_source is not None and bias_source.device.type == "meta":
loaded_bias = comfy.disk_weights.load_module_tensor(
s,
"bias",
device,
temporary=True,
dtype_override=bias_dtype,
)
if loaded_bias is not None:
bias_source = loaded_bias
weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_source is not None:
bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
@ -532,9 +557,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
if value.device.type != "meta":
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
@ -551,11 +577,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
if layer_conf is not None and layer_conf.device.type != "meta":
layer_conf = json.loads(layer_conf.numpy().tobytes())
elif layer_conf is not None:
layer_conf = None
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@ -601,10 +632,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
@ -614,7 +648,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
if _v.device.type == "meta":
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
else:
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

comfy/safetensors_stream.py (new file, 934 lines)

@ -0,0 +1,934 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import collections
import ctypes
import importlib
import importlib.util
import os
import threading
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple
import torch
_FST_MODULE = None
_FST_LOCK = threading.Lock()
_FST_LOADED = False
_GDS_INITIALIZED = False
_MISSING = object()
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
def _require_fastsafetensors():
global _FST_MODULE
with _FST_LOCK:
if _FST_MODULE is None:
if importlib.util.find_spec("fastsafetensors") is None:
raise ImportError(
"fastsafetensors is required for safetensors streaming. "
"Install it with: pip install 'fastsafetensors @ https://github.com/"
"foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
)
_FST_MODULE = importlib.import_module("fastsafetensors")
return _FST_MODULE
def _init_fastsafetensors_lib():
global _FST_LOADED
fst = _require_fastsafetensors()
if not _FST_LOADED:
fst.cpp.load_library_functions()
_FST_LOADED = True
return fst
def _init_gds():
global _GDS_INITIALIZED
fst = _init_fastsafetensors_lib()
if not _GDS_INITIALIZED:
if fst.cpp.init_gds() != 0:
raise RuntimeError("fastsafetensors init_gds() failed")
_GDS_INITIALIZED = True
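# Per-tensor header metadata (dtype, shape, byte range in the file); building this never reads tensor data.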
@dataclass(frozen=True)
class TensorMeta:
dtype: torch.dtype
shape: Tuple[int, ...]
numel: int
nbytes: int
data_offsets: Tuple[int, int]
filename: str
fst_dtype: object
strides: Tuple[int, ...]
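# Parses the safetensors header once via fastsafetensors and exposes metadata lookups keyed by tensor name.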
class SafeTensorIndex:
def __init__(self, filename: str):
fst = _init_fastsafetensors_lib()
framework = fst.frameworks.get_framework_op("pytorch")
metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
self._filename = filename
self._metadata = metadata
self._framework = framework
from fastsafetensors.frameworks import _torch as fst_torch
self._dtype_map = fst_torch.dtype_convert
self._tensor_meta: Dict[str, TensorMeta] = {}
for key, frame in metadata.tensors.items():
torch_dtype = self._dtype_map.get(frame.dtype, None)
if torch_dtype is None:
raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
numel = 1
for s in frame.shape:
numel *= s
nbytes = numel * framework.get_dtype_size(frame.dtype)
self._tensor_meta[key] = TensorMeta(
dtype=torch_dtype,
shape=tuple(frame.shape),
numel=numel,
nbytes=nbytes,
data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
filename=filename,
fst_dtype=frame.dtype,
strides=tuple(frame.strides),
)
def keys(self) -> Iterable[str]:
return self._tensor_meta.keys()
def has(self, key: str) -> bool:
return key in self._tensor_meta
def meta(self, key: str) -> TensorMeta:
return self._tensor_meta[key]
def metadata(self):
return self._metadata.metadata
@property
def header_length(self) -> int:
return self._metadata.header_length
@property
def size_bytes(self) -> int:
return self._metadata.size_bytes
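# Wraps one open safetensors file: a lazily created fd plus the GDS / no-GDS reader objects, reference counted so StreamStateDict.copy() can share it.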
class _SafeTensorFile:
def __init__(self, filename: str, index: SafeTensorIndex):
self.filename = filename
self.index = index
self._fd: Optional[int] = None
self._gds_handle = None
self._gds_reader = None
self._nogds_reader = None
self._refcount = 1
def acquire(self) -> "_SafeTensorFile":
self._refcount += 1
return self
def release(self):
self._refcount -= 1
if self._refcount <= 0:
self.close()
def close(self):
if self._fd is not None:
os.close(self._fd)
self._fd = None
self._gds_handle = None
def _ensure_fd(self) -> int:
if self._fd is None:
self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
return self._fd
def _ensure_nogds_reader(self, use_cuda: bool):
fst = _init_fastsafetensors_lib()
if self._nogds_reader is None:
self._nogds_reader = fst.cpp.nogds_file_reader(
False, 16 * 1024, 16, use_cuda
)
return self._nogds_reader
def _ensure_gds_reader(self, use_cuda: bool):
fst = _init_fastsafetensors_lib()
if self._gds_reader is None:
self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
return self._gds_reader
def _ensure_gds_handle(self, use_cuda: bool):
if self._gds_handle is None:
fst = _init_fastsafetensors_lib()
framework = fst.frameworks.get_framework_op("pytorch")
o_direct = _get_gds_o_direct(framework)
self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
return self._gds_handle
def read_tensor(
self,
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
allow_gds: bool,
pin_if_cpu: bool,
) -> torch.Tensor:
fst = _init_fastsafetensors_lib()
framework = fst.frameworks.get_framework_op("pytorch")
device_is_cuda = device.type == "cuda"
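# GPUDirect path reads straight from disk into GPU memory; otherwise read into CPU memory and copy to the GPU (optionally through pinned memory).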
if device_is_cuda and allow_gds:
_ensure_gds_ready(device)
tensor = self._read_tensor_gds(
fst, framework, meta, device, dtype
)
return tensor
cpu_tensor = self._read_tensor_nogds(
fst, framework, meta, torch.device("cpu"), dtype
)
if device_is_cuda:
if pin_if_cpu:
cpu_tensor = cpu_tensor.pin_memory()
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
return gpu_tensor
return cpu_tensor
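# Expand a byte range to the O_DIRECT/GDS alignment boundary; 'head' is how far the payload starts inside the aligned window.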
def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
fst = _init_fastsafetensors_lib()
align = fst.cpp.get_alignment_size()
aligned_offset = (abs_start // align) * align
head = abs_start - aligned_offset
aligned_length = length + head
tail = aligned_length % align
if tail:
aligned_length += align - tail
return aligned_offset, aligned_length, head
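# Bounce-buffer path: read aligned chunks into a reusable CPU buffer allocated by fastsafetensors, then memmove each chunk into the destination tensor's storage.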
def _read_tensor_nogds(
self,
fst,
framework,
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
fd = self._ensure_fd()
reader = self._ensure_nogds_reader(use_cuda=False)
abs_start = self.index.header_length + meta.data_offsets[0]
length = meta.data_offsets[1] - meta.data_offsets[0]
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
chunk_bytes = max(1, chunk_bytes)
ptr_align = framework.get_device_ptr_align()
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
buffer_length = 0
buf_ptr = None
gbuf = None
try:
chunk_offset = 0
while chunk_offset < length:
chunk_len = min(length - chunk_offset, chunk_bytes)
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
needed = aligned_length + ptr_align
if buf_ptr is None or needed > buffer_length:
if buf_ptr is not None:
fst.cpp.cpu_free(buf_ptr)
buffer_length = needed
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
if reader.wait_read(req) < 0:
raise RuntimeError("nogds_file_reader read failed")
src_ptr = gbuf.get_base_address() + ptr_off + head
dest_ptr = dest_tensor.data_ptr() + chunk_offset
ctypes.memmove(dest_ptr, src_ptr, chunk_len)
chunk_offset += chunk_len
except Exception:
if buf_ptr is not None:
fst.cpp.cpu_free(buf_ptr)
raise
if buf_ptr is not None:
fst.cpp.cpu_free(buf_ptr)
if dtype is not None and dtype != dest_tensor.dtype:
_validate_dtype_conversion(dest_tensor.dtype, dtype)
dest_tensor = dest_tensor.to(dtype=dtype)
return dest_tensor
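# GDS path: cuFile reads directly into a device buffer, which is then wrapped as a torch tensor via DLPack without an extra copy.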
def _read_tensor_gds(
self,
fst,
framework,
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
) -> torch.Tensor:
reader = self._ensure_gds_reader(use_cuda=True)
handle = self._ensure_gds_handle(use_cuda=True)
abs_start = self.index.header_length + meta.data_offsets[0]
length = meta.data_offsets[1] - meta.data_offsets[0]
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
ptr_align = framework.get_device_ptr_align()
buffer_length = aligned_length + ptr_align
fst_device = _fst_device_from_torch(fst, device)
gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
file_length = self.index.size_bytes
req = reader.submit_read(
handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
)
if reader.wait_read(req) < 0:
framework.free_tensor_memory(gbuf, fst_device)
raise RuntimeError("gds_file_reader read failed")
owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
tensor = _dlpack_tensor_from_buffer(
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
)
if dtype is not None and dtype != tensor.dtype:
_validate_dtype_conversion(tensor.dtype, dtype)
tensor = tensor.to(dtype=dtype)
return tensor
def _fst_device_from_torch(fst, device: torch.device):
if device.type == "cuda" and device.index is not None:
return fst.st_types.Device.from_str(f"cuda:{device.index}")
return fst.st_types.Device.from_str(device.type)
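# Ties the lifetime of a raw fastsafetensors buffer to a Python object; the buffer is freed when the owner is garbage collected.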
class _BufferOwner:
def __init__(self, free_fn):
self._free_fn = free_fn
def __del__(self):
try:
self._free_fn()
except Exception:
pass
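# Wraps a raw pointer as a torch tensor via DLPack; the owner is stashed on the tensor so the backing buffer stays alive as long as the tensor does.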
def _dlpack_tensor_from_buffer(
fst,
framework,
ptr: int,
meta: TensorMeta,
device: torch.device,
owner: Optional[_BufferOwner],
) -> torch.Tensor:
disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
dev = _fst_device_from_torch(fst, device)
dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
if disk_dtype != meta.fst_dtype:
torch_tensor = torch_tensor.view(meta.dtype)
if owner is not None:
torch_tensor._comfy_disk_buffer_owner = owner
return torch_tensor
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
def _get_gds_o_direct(framework) -> bool:
cuda_ver = framework.get_cuda_ver()
if cuda_ver and cuda_ver != "0.0":
ver_parts = cuda_ver.split("-", 1)
if len(ver_parts) == 2:
cudavers = list(map(int, ver_parts[1].split(".")))
if ver_parts[0] == "cuda":
return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
return True
return True
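# Validate GPUDirect prerequisites up front so a missing GPU runtime, libcufile, or platform support fails with an actionable message instead of a crash deep in the reader.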
def _ensure_gds_ready(device: torch.device):
fst = _init_fastsafetensors_lib()
if not fst.common.is_gpu_found():
raise RuntimeError(
"GPUDirect requested but GPU runtime library is missing. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
if gds_supported < 0:
raise RuntimeError(
"GPUDirect requested but is_gds_supported() failed. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
if not fst.cpp.is_cufile_found():
raise RuntimeError(
"GPUDirect requested but libcufile is missing. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
if gds_supported == 0:
raise RuntimeError(
"GPUDirect requested but GDS is unsupported on this platform. "
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
)
_init_gds()
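# Dict-like view over a safetensors file. Keys and meta() come from the header; tensors are only read from disk when accessed. Assignments are kept in an in-memory override map and deletions are tracked in a set, so the file itself is never modified.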
class StreamStateDict(collections.abc.MutableMapping):
is_stream_state_dict = True
def __init__(
self,
index: SafeTensorIndex,
file: _SafeTensorFile,
device: torch.device,
allow_gds: bool = False,
):
self._index = index
self._file = file
self._device = device
self._allow_gds = allow_gds
self._overrides: Dict[str, torch.Tensor] = {}
self._deleted: set[str] = set()
@classmethod
def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
index = SafeTensorIndex(filename)
file = _SafeTensorFile(filename, index)
return cls(index, file, device, allow_gds=allow_gds)
def close(self):
if self._file is not None:
self._file.release()
self._file = None
def __del__(self):
try:
self.close()
except Exception:
pass
def meta(self, key: str) -> TensorMeta:
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return TensorMeta(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
data_offsets=(0, numel * t.element_size()),
filename="<override>",
fst_dtype=None,
strides=tuple(t.stride()),
)
if key in self._deleted:
raise KeyError(key)
return self._index.meta(key)
def get_tensor(
self,
key: str,
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
) -> torch.Tensor:
if key in self._overrides:
t = self._overrides[key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
_validate_dtype_conversion(t.dtype, dtype)
t = t.to(dtype=dtype)
return t
if key in self._deleted:
raise KeyError(key)
if device is None:
device = self._device
if device.type == "meta":
meta = self._index.meta(key)
target_dtype = dtype or meta.dtype
if dtype is not None and dtype != meta.dtype:
_validate_dtype_conversion(meta.dtype, dtype)
return torch.empty(meta.shape, dtype=target_dtype, device="meta")
if allow_gds is None:
allow_gds = self._allow_gds
meta = self._index.meta(key)
return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)
def __getitem__(self, key: str) -> torch.Tensor:
return self.get_tensor(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._deleted:
raise KeyError(key)
if self._index.has(key):
self._deleted.add(key)
return
raise KeyError(key)
def __iter__(self) -> Iterator[str]:
for k in self._index.keys():
if k in self._deleted:
continue
if k in self._overrides:
continue
yield k
for k in self._overrides.keys():
yield k
def __len__(self) -> int:
# Count through __iter__ so overrides that shadow on-disk keys are not double counted.
return sum(1 for _ in self.__iter__())
def __contains__(self, key: object) -> bool:
if not isinstance(key, str):
return False
if key in self._deleted:
return False
if key in self._overrides:
return True
return self._index.has(key)
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
if key in self._overrides:
return self._overrides.pop(key)
if key in self._deleted:
if default is _MISSING:
raise KeyError(key)
return default
if self._index.has(key):
# Read the tensor before marking the key deleted; get_tensor() raises KeyError for deleted keys.
value = self.get_tensor(key)
self._deleted.add(key)
return value
if default is _MISSING:
raise KeyError(key)
return default
def copy(self) -> "StreamStateDict":
new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
new._overrides = dict(self._overrides)
new._deleted = set(self._deleted)
return new
def metadata(self):
return self._index.metadata()
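# Common machinery for lazy views (filter/prefix/rename/map) over another state dict: keys are translated through _resolve_base_key and local writes/deletes layer on top of the base mapping.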
class _BaseViewStateDict(MutableMapping):
is_stream_state_dict = True
def __init__(self, base: MutableMapping, mutate_base: bool = False):
self._base = base
self._mutate_base = mutate_base
self._overrides: Dict[str, torch.Tensor] = {}
self._deleted: set[str] = set()
def _resolve_base_key(self, key: str) -> Optional[str]:
return key
def _iter_base_keys(self) -> Iterable[str]:
return self._base.keys()
def get_tensor(
self,
key: str,
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
) -> torch.Tensor:
if key in self._overrides:
t = self._overrides[key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
_validate_dtype_conversion(t.dtype, dtype)
t = t.to(dtype=dtype)
return t
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if hasattr(self._base, "get_tensor"):
return self._base.get_tensor(
base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
)
t = self._base[base_key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
_validate_dtype_conversion(t.dtype, dtype)
t = t.to(dtype=dtype)
return t
def __getitem__(self, key: str) -> torch.Tensor:
return self.get_tensor(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
base_key = self._resolve_base_key(key)
if self._mutate_base and base_key is not None and base_key in self._base:
self._base[base_key] = value
else:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if self._mutate_base and base_key in self._base:
del self._base[base_key]
else:
self._deleted.add(key)
def __iter__(self) -> Iterator[str]:
for k in self._iter_base_keys():
if k in self._deleted or k in self._overrides:
continue
yield k
for k in self._overrides.keys():
yield k
def __len__(self) -> int:
# Count through __iter__ so keys present both in the base and in the overrides are not counted twice.
return sum(1 for _ in self.__iter__())
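# One possible __contains__ override so membership tests stay metadata-only; the MutableMapping default would call __getitem__ and pull the tensor from disk just to answer `key in sd`.
def __contains__(self, key: object) -> bool:
if not isinstance(key, str):
return False
if key in self._overrides:
return True
if key in self._deleted:
return False
base_key = self._resolve_base_key(key)
return base_key is not None and base_key in self._base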
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
if key in self._overrides:
return self._overrides.pop(key)
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
if default is _MISSING:
raise KeyError(key)
return default
if self._mutate_base:
try:
return self._base.pop(base_key)
except KeyError:
if default is _MISSING:
raise
return default
value = self.get_tensor(key)
self._deleted.add(key)
return value
def meta(self, key: str):
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if hasattr(self._base, "meta"):
return self._base.meta(base_key)
t = self._base[base_key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
class DeviceViewStateDict(_BaseViewStateDict):
def __init__(
self,
base: MutableMapping,
device: torch.device,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
mutate_base: bool = False,
):
super().__init__(base, mutate_base=mutate_base)
self._device = device
self._allow_gds = allow_gds
self._pin_if_cpu = pin_if_cpu
def get_tensor(
self,
key: str,
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
) -> torch.Tensor:
device = self._device if device is None else device
allow_gds = self._allow_gds if allow_gds is None else allow_gds
pin_if_cpu = pin_if_cpu or self._pin_if_cpu
return super().get_tensor(
key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
)
def meta(self, key: str):
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if hasattr(self._base, "meta"):
return self._base.meta(base_key)
t = self._base[base_key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
def __getitem__(self, key: str) -> torch.Tensor:
return self.get_tensor(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
base_key = self._resolve_base_key(key)
if self._mutate_base and base_key is not None and base_key in self._base:
self._base[base_key] = value
else:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
raise KeyError(key)
if self._mutate_base and base_key in self._base:
del self._base[base_key]
else:
self._deleted.add(key)
def __iter__(self) -> Iterator[str]:
for k in self._iter_base_keys():
if k in self._deleted or k in self._overrides:
continue
yield k
for k in self._overrides.keys():
yield k
def __len__(self) -> int:
return sum(1 for _ in self.__iter__())
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
if key in self._overrides:
return self._overrides.pop(key)
base_key = self._resolve_base_key(key)
if base_key is None or key in self._deleted:
if default is _MISSING:
raise KeyError(key)
return default
if self._mutate_base:
try:
return self._base.pop(base_key)
except KeyError:
if default is _MISSING:
raise
return default
value = self.get_tensor(key)
self._deleted.add(key)
return value
class FilterViewStateDict(_BaseViewStateDict):
def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
super().__init__(base, mutate_base=mutate_base)
self._predicate = predicate
def _resolve_base_key(self, key: str) -> Optional[str]:
if self._predicate(key):
return key
return None
def _iter_base_keys(self) -> Iterable[str]:
for k in self._base.keys():
if self._predicate(k):
yield k
class PrefixViewStateDict(_BaseViewStateDict):
def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
super().__init__(base, mutate_base=mutate_base)
self._source_prefix = source_prefix
self._target_prefix = target_prefix
self._mapping: Dict[str, str] = {}
self._reverse: Dict[str, str] = {}
for k in base.keys():
if not k.startswith(source_prefix):
continue
view_key = f"{target_prefix}{k[len(source_prefix):]}"
self._mapping[k] = view_key
self._reverse[view_key] = k
def _resolve_base_key(self, key: str) -> Optional[str]:
return self._reverse.get(key)
def _iter_base_keys(self) -> Iterable[str]:
return self._reverse.keys()
class RenameViewStateDict(_BaseViewStateDict):
def __init__(
self,
base: MutableMapping,
replace_prefix: Mapping[str, str],
filter_keys: bool = False,
mutate_base: bool = False,
):
super().__init__(base, mutate_base=mutate_base)
self._filter_keys = filter_keys
self._replace = list(replace_prefix.items())
self._mapping: Dict[str, str] = {}
self._reverse: Dict[str, str] = {}
for k in base.keys():
view_key = self._replace_key(k)
if view_key is None:
continue
self._mapping[k] = view_key
self._reverse[view_key] = k
def _replace_key(self, key: str) -> Optional[str]:
for rp, dst in self._replace:
if key.startswith(rp):
return f"{dst}{key[len(rp):]}"
if self._filter_keys:
return None
return key
def _resolve_base_key(self, key: str) -> Optional[str]:
return self._reverse.get(key)
def _iter_base_keys(self) -> Iterable[str]:
return self._reverse.keys()
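# Overlay of several state dicts; on key collisions the mapping passed last wins (lookups walk the list in reverse).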
class MergedStateDict(MutableMapping):
is_stream_state_dict = True
def __init__(self, *mappings: MutableMapping):
self._mappings = list(mappings)
self._overrides: Dict[str, torch.Tensor] = {}
self._deleted: set[str] = set()
def __getitem__(self, key: str) -> torch.Tensor:
if key in self._overrides:
return self._overrides[key]
if key in self._deleted:
raise KeyError(key)
for mapping in reversed(self._mappings):
if key in mapping:
if hasattr(mapping, "get_tensor"):
return mapping.get_tensor(key)
return mapping[key]
raise KeyError(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._deleted:
raise KeyError(key)
if any(key in mapping for mapping in self._mappings):
self._deleted.add(key)
return
raise KeyError(key)
def __iter__(self) -> Iterator[str]:
seen = set()
for mapping in self._mappings:
for key in mapping.keys():
if key in self._deleted or key in seen:
continue
seen.add(key)
yield key
for key in self._overrides.keys():
if key not in seen:
yield key
def __len__(self) -> int:
return len(list(self.__iter__()))
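# One possible __contains__ so `key in merged` checks the child mappings instead of falling back to __getitem__, which would read a tensor from disk.
def __contains__(self, key: object) -> bool:
if not isinstance(key, str):
return False
if key in self._overrides:
return True
if key in self._deleted:
return False
return any(key in mapping for mapping in self._mappings)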
def meta(self, key: str):
if key in self._overrides:
t = self._overrides[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
if key in self._deleted:
raise KeyError(key)
for mapping in reversed(self._mappings):
if key in mapping:
if hasattr(mapping, "meta"):
return mapping.meta(key)
t = mapping[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
raise KeyError(key)
class MappedStateDict(_BaseViewStateDict):
def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
super().__init__(base, mutate_base=mutate_base)
self._base_to_view = dict(key_map)
self._view_to_base = {v: k for k, v in key_map.items()}
def _resolve_base_key(self, key: str) -> Optional[str]:
return self._view_to_base.get(key)
def _iter_base_keys(self) -> Iterable[str]:
return self._view_to_base.keys()

View File

@ -25,6 +25,8 @@ import math
import os
import comfy.utils
import comfy.safetensors_stream
import comfy.disk_weights
from . import clip_vision
from . import gligen
@ -124,7 +126,7 @@ class CLIP:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
self.cond_stage_model.to(offload_device)
comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
logging.warning("Had to shift TE back.")
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@ -288,7 +290,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@ -349,7 +351,7 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
@ -364,25 +366,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["encoder.{}".format(k)] = sd[k]
sd = new_sd
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["previewer.{}".format(k)] = sd[k]
sd = new_sd
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."})
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -392,9 +388,9 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@ -417,7 +413,7 @@ class VAE:
self.downscale_ratio = 4
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
@ -430,7 +426,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@ -465,11 +461,11 @@ class VAE:
self.downscale_index_formula = (6, 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight")
version = 0
if tensor_conv1.shape[0] == 512:
if tensor_conv1_shape[0] == 512:
version = 0
elif tensor_conv1.shape[0] == 1024:
elif tensor_conv1_shape[0] == 1024:
version = 1
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
@ -486,9 +482,9 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
self.latent_channels = 32
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@ -512,8 +508,8 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
#This is likely to significantly over-estimate with single image or low frame counts as the
#implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
@ -546,14 +542,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
dim = sd["decoder.head.0.gamma"].shape[0]
dim = sd_shape(sd, "decoder.head.0.gamma")[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
@ -629,7 +625,7 @@ class VAE:
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
self.latent_channels = sd["decoder.1.weight"].shape[1]
self.latent_channels = sd_shape(sd, "decoder.1.weight")[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@ -640,12 +636,12 @@ class VAE:
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
@ -665,7 +661,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@ -679,7 +675,7 @@ class VAE:
if dtype is None:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@ -986,9 +982,12 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
comfy.utils.load_state_dict(model, model_data, strict=True)
return StyleModel(model)
def sd_shape(state_dict, key):
return comfy.utils.state_dict_meta(state_dict, key).shape
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
@ -1058,16 +1057,16 @@ def detect_te_model(sd):
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight")
if weight_shape[-1] == 4096:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
elif weight_shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
if weight.shape[0] == 384:
weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight')
if weight_shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
@ -1077,19 +1076,19 @@ def detect_te_model(sd):
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias')
if weight_shape[0] == 256:
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
if weight_shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight')
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if weight.shape[0] == 2560:
if weight_shape[0] == 2560:
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
elif weight_shape[0] == 2048:
return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if weight_shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
@ -1418,19 +1417,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
out_sd = {}
for k in sd:
skip = False
if comfy.utils.is_stream_state_dict(sd):
def _keep_key(k, prefixes=tuple(scaled_fp8_list)):
return not any(k.startswith(pref) for pref in prefixes)
out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
merged = out_sd
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd)
sd = merged
else:
out_sd = {}
for k in sd:
skip = False
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
@ -1508,12 +1517,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
if comfy.utils.is_stream_state_dict(sd):
new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys)
else:
new_sd = {}
for k in diffusers_keys:
if k in sd:
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
@ -1538,7 +1550,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model = comfy.disk_weights.module_to(model, offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:

View File

@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
def parse_parentheses(string):
result = []
@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
try:
if embed_path.lower().endswith(".safetensors"):
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
embed = comfy.utils.load_torch_file(embed_path, safe_load=True)
else:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")

View File

@ -56,9 +56,9 @@ class TAESD(nn.Module):
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
if decoder_path is not None:
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
@staticmethod
def scale_latents(x):

View File

@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
if len(sdo) == 0:
sdo = sd
return self.load_state_dict(sdo, strict=False)
return comfy.utils.load_state_dict(self, sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0

View File

@ -26,10 +26,13 @@ import numpy as np
from PIL import Image
import logging
import itertools
from types import SimpleNamespace
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
from . import safetensors_stream
import comfy.disk_weights
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
if return_metadata:
metadata = f.metadata()
sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device)
if return_metadata:
metadata = sd.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
params += w.nelement()
meta = state_dict_meta(sd, k)
params += meta.numel
return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
meta = state_dict_meta(sd, k)
dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel
if len(dtypes) == 0:
return None
@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace):
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if is_stream_state_dict(state_dict):
return safetensors_stream.RenameViewStateDict(
state_dict,
replace_prefix,
filter_keys=filter_keys,
mutate_base=not filter_keys,
)
if filter_keys:
out = {}
else:
@ -145,6 +149,79 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
return out
def is_stream_state_dict(state_dict) -> bool:
return getattr(state_dict, "is_stream_state_dict", False)
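# Cheap dtype/shape/numel lookup: uses the mapping's meta() when available (no disk read) and falls back to inspecting the materialized tensor for plain dicts.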
def state_dict_meta(state_dict, key):
if hasattr(state_dict, "meta"):
return state_dict.meta(key)
w = state_dict[key]
numel = w.numel()
return SimpleNamespace(
dtype=w.dtype,
shape=tuple(w.shape),
numel=numel,
nbytes=numel * w.element_size(),
)
def load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict):
if comfy.disk_weights.disk_weights_enabled():
return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
comfy.disk_weights.register_module_weights(model, state_dict)
comfy.disk_weights.attach_disk_weight_hooks(model)
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
return missing, unexpected
return model.load_state_dict(state_dict, strict=strict)
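# Mirrors torch.nn.Module.load_state_dict but hands every submodule a lazy FilterViewStateDict of its own prefix, so only the tensors each module actually consumes get materialized.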
def stream_load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = safetensors_stream.FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return missing_keys, unexpected_keys
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
@ -825,7 +902,10 @@ def copy_to_param(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
if prev.device.type == "meta":
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad))
else:
prev.data.copy_(value)
def get_attr(obj, attr: str):
"""Retrieves a nested attribute from an object using dot notation.
@ -1217,46 +1297,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
if is_stream_state_dict(state_dict):
scaled_meta = state_dict_meta(state_dict, scaled_fp8_key)
scaled_fp8_dtype = scaled_meta.dtype
scaled_fp8_weight_nelements = scaled_meta.numel
else:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
scaled_fp8_weight_nelements = scaled_fp8_weight.nelement()
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
if scaled_fp8_weight_nelements == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
out_sd = {}
layers = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
if is_stream_state_dict(state_dict):
key_map = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
key_map[k] = k
continue
k_out = k
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
out_sd[k_out] = w
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
state_dict = out_sd
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
scale_val = state_dict[k]
if scale_val.item() == 1.0:
continue
key_map[k] = k_out
state_dict = safetensors_stream.MappedStateDict(state_dict, key_map)
else:
out_sd = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
continue
out_sd[k_out] = w
state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])

View File

@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@ -526,7 +525,7 @@ class LoadLatent:
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
latent = comfy.utils.load_torch_file(latent_path, safe_load=True)
multiplier = 1.0
if "latent_format_version_0" not in latent:
multiplier = 1.0 / 0.18215

View File

@ -11,6 +11,7 @@ transformers>=4.50.3
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip
aiohttp>=3.11.8
yarl>=1.18.0
pyyaml

View File

@ -0,0 +1,182 @@
import os
import pytest
import importlib
import importlib.util
torch = pytest.importorskip("torch")
def _write_safetensors(tmp_path, tensors):
import safetensors.torch
path = os.path.join(tmp_path, "test.safetensors")
safetensors.torch.save_file(tensors, path)
return path
def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
meta = sd.meta("a")
assert meta.shape == (2, 3)
assert meta.dtype == torch.float32
assert meta.numel == 6
assert calls == []
def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
_ = sd["a"]
assert len(calls) == 1
assert calls[0].shape == (2, 3)
def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
sd = comfy.utils.load_torch_file(path, safe_load=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
_ = list(view.keys())
assert calls == []
def test_stream_load_rss_small(tmp_path):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
psutil = pytest.importorskip("psutil")
process = psutil.Process()
size_elems = 4_000_000 # ~16MB float32
tensor = torch.zeros((size_elems,), dtype=torch.float32)
path = _write_safetensors(tmp_path, {"big": tensor})
rss_before = process.memory_info().rss
sd = comfy.utils.load_torch_file(path, safe_load=True)
rss_after = process.memory_info().rss
expected_size = tensor.numel() * tensor.element_size()
assert (rss_after - rss_before) < expected_size
_ = sd.meta("big")
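# When GDS is unavailable the request must raise; when it is available, the no-GDS fallback path must not be touched.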
def test_gds_path_errors_without_support(tmp_path, monkeypatch):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
device = torch.device("cuda")
if importlib.util.find_spec("fastsafetensors") is None:
fst = None
else:
fst = importlib.import_module("fastsafetensors")
gds_available = False
if fst is not None and torch.cuda.is_available():
gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1
if not gds_available:
with pytest.raises(RuntimeError, match="GPUDirect requested"):
sd.get_tensor("a", device=device, allow_gds=True)
else:
def fail_nogds(*args, **kwargs):
raise AssertionError("nogds path used during GDS request")
monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
t = sd.get_tensor("a", device=device, allow_gds=True)
assert t.device.type == "cuda"
def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
if torch is None:
pytest.skip("torch not installed")
import comfy.utils
import comfy.disk_weights
prev_cache = comfy.disk_weights.CACHE.max_bytes
prev_gds = comfy.disk_weights.ALLOW_GDS
prev_pin = comfy.disk_weights.PIN_IF_CPU
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
try:
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
model = torch.nn.Linear(4, 4, bias=True)
comfy.utils.load_state_dict(model, sd, strict=False)
assert model.weight.device.type == "cpu"
assert model.weight.device.type != "meta"
finally:
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
if importlib.util.find_spec("fastsafetensors") is None:
pytest.skip("fastsafetensors not installed")
import comfy.utils
import comfy.disk_weights
prev_cache = comfy.disk_weights.CACHE.max_bytes
prev_gds = comfy.disk_weights.ALLOW_GDS
prev_pin = comfy.disk_weights.PIN_IF_CPU
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
try:
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
sd = comfy.utils.load_torch_file(path, safe_load=True)
model = torch.nn.Linear(4, 4, bias=True)
calls = []
original = sd._file.read_tensor
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
calls.append(meta)
return original(meta, device, dtype, allow_gds, pin_if_cpu)
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
comfy.utils.load_state_dict(model, sd, strict=True)
assert model.weight.device.type == "meta"
assert calls == []
comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
assert model.weight.device.type == "cpu"
assert len(calls) == 2
finally:
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)