mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-14 00:00:57 +08:00
Merge b93f165c2e into ec0a832acb
This commit is contained in:
commit
7d9b1b0885
57
DESIGN.md
Normal file
57
DESIGN.md
Normal file
@ -0,0 +1,57 @@
|
||||
# Disk tier safetensors streaming design audit (ComfyUI)
|
||||
|
||||
## Mandatory research audit (verified call sites)
|
||||
|
||||
### ComfyUI load path + eager materialization sites
|
||||
- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】
|
||||
- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】
|
||||
- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes if used on a streaming mapping).【F:comfy/utils.py†L135-L144】
|
||||
- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】
|
||||
- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
|
||||
- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
|
||||
- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
|
||||
|
||||
### FastSageTensors (fastsafetensors) capability audit
|
||||
- Header parsing and metadata:
|
||||
- `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
|
||||
- `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
|
||||
- GDS + no-GDS low-level readers:
|
||||
- `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
|
||||
- GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
|
||||
- DLPack wrapping:
|
||||
- `fastsafetensors/dlpack.py` provides `from_cuda_buffer()` which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used for `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
|
||||
- Torch framework interop:
|
||||
- `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
|
||||
|
||||
### VRAM/RAM offload logic (for extension)
|
||||
- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps tracking of loaded/offloaded memory (needs integration for RAM disk tier).【F:comfy/model_management.py†L584-L612】
|
||||
- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
|
||||
|
||||
## Strategy summary (implemented)
|
||||
|
||||
### Streaming safetensors mapping (no full dict materialization)
|
||||
- [x] Introduce a new module `comfy/safetensors_stream.py` with:
|
||||
- [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
|
||||
- [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
|
||||
- [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
|
||||
|
||||
### Range reads and tiering
|
||||
- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
|
||||
- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
|
||||
- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
|
||||
|
||||
### Disk tier integration
|
||||
- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
|
||||
- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
|
||||
- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
|
||||
|
||||
### Pipeline refactors
|
||||
- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
|
||||
- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
|
||||
- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
|
||||
- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
|
||||
- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
|
||||
|
||||
### Tests and docs
|
||||
- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
|
||||
- [x] Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
|
||||
@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
|
||||
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
|
||||
|
||||
### Disk tier for model weights
|
||||
|
||||
When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
|
||||
|
||||
If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
@ -29,7 +29,7 @@ class AudioEncoderModel():
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return comfy.utils.load_state_dict(self.model, sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
|
||||
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
|
||||
@ -48,7 +48,7 @@ class ClipVisionModel():
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return comfy.utils.load_state_dict(self.model, sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -25,6 +25,7 @@ import logging
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
import comfy.disk_weights
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
@ -385,7 +386,7 @@ class ControlLora(ControlNet):
|
||||
controlnet_config["operations"] = control_lora_ops
|
||||
controlnet_config["dtype"] = dtype
|
||||
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
self.control_model.to(comfy.model_management.get_torch_device())
|
||||
comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
|
||||
@ -439,7 +440,7 @@ def controlnet_config(sd, model_options={}):
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
@ -473,9 +474,9 @@ def load_controlnet_mmdit(sd, model_options={}):
|
||||
class ControlNetSD35(ControlNet):
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
if self.control_model.double_y_emb:
|
||||
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
else:
|
||||
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def copy(self):
|
||||
@ -748,9 +749,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
||||
missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
|
||||
else:
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
@ -816,8 +817,8 @@ class T2IAdapter(ControlBase):
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
if self.control_input is None:
|
||||
self.t2i_model.to(x_noisy.dtype)
|
||||
self.t2i_model.to(self.device)
|
||||
comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
|
||||
comfy.disk_weights.module_to(self.t2i_model, self.device)
|
||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||
self.t2i_model.cpu()
|
||||
|
||||
@ -874,7 +875,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
else:
|
||||
return None
|
||||
|
||||
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||
missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
|
||||
if len(missing) > 0:
|
||||
logging.warning("t2i missing {}".format(missing))
|
||||
|
||||
|
||||
1115
comfy/disk_weights.py
Normal file
1115
comfy/disk_weights.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@ import torch
|
||||
from torch import nn
|
||||
from .ldm.modules.attention import CrossAttention, FeedForward
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
ops = comfy.ops.manual_cast
|
||||
|
||||
|
||||
@ -282,7 +283,7 @@ def load_gligen(sd):
|
||||
|
||||
gated = GatedSelfAttentionDense(
|
||||
query_dim, key_dim, n_heads, d_head)
|
||||
gated.load_state_dict(n_sd, strict=False)
|
||||
comfy.utils.load_state_dict(gated, n_sd, strict=False)
|
||||
output_list.append(gated)
|
||||
|
||||
if "position_net.null_positive_feature" in sd_k:
|
||||
@ -293,7 +294,7 @@ def load_gligen(sd):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.position_net = PositionNet(in_dim, out_dim)
|
||||
w.load_state_dict(sd, strict=False)
|
||||
comfy.utils.load_state_dict(w, sd, strict=False)
|
||||
|
||||
gligen = Gligen(output_list, w.position_net, key_dim)
|
||||
return gligen
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
return comfy.utils.load_state_dict(self.model, sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import comfy.utils
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
|
||||
comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.utils
|
||||
import torch.nn as nn
|
||||
|
||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
@ -152,7 +153,7 @@ class VAE(nn.Module):
|
||||
return dec, posterior
|
||||
|
||||
def load_weights(self, src_dict) -> None:
|
||||
self.load_state_dict(src_dict, strict=True)
|
||||
comfy.utils.load_state_dict(self, src_dict, strict=True)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
|
||||
if name == '44k':
|
||||
return VAE_44k(**kwargs)
|
||||
raise ValueError(f'Unknown model: {name}')
|
||||
|
||||
|
||||
@ -56,6 +56,7 @@ import comfy.conds
|
||||
import comfy.ops
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
from . import safetensors_stream
|
||||
import comfy.latent_formats
|
||||
import comfy.model_sampling
|
||||
import math
|
||||
@ -299,20 +300,21 @@ class BaseModel(torch.nn.Module):
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
to_load = {}
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k.startswith(unet_prefix):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
replace_prefix = {unet_prefix: ""} if unet_prefix else {}
|
||||
if replace_prefix:
|
||||
if utils.is_stream_state_dict(sd):
|
||||
to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
|
||||
else:
|
||||
to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
|
||||
else:
|
||||
to_load = sd
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("unet missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.warning("unet unexpected: {}".format(u))
|
||||
del to_load
|
||||
return self
|
||||
|
||||
def process_latent_in(self, latent):
|
||||
@ -751,8 +753,8 @@ class StableAudio1(BaseModel):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
|
||||
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
|
||||
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
|
||||
utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
|
||||
utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def sd_shape(state_dict, key):
|
||||
return comfy.utils.state_dict_meta(state_dict, key).shape
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
||||
if len(transformer_keys) > 0:
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1]
|
||||
use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2
|
||||
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
||||
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
|
||||
@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
||||
unet_config = {}
|
||||
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
|
||||
unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1]
|
||||
patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2]
|
||||
unet_config["patch_size"] = patch_size
|
||||
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
|
||||
if final_layer in state_dict:
|
||||
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
|
||||
unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size)
|
||||
|
||||
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
|
||||
unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64
|
||||
unet_config["input_size"] = None
|
||||
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
|
||||
if y_key in state_dict_keys:
|
||||
unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
|
||||
unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1]
|
||||
|
||||
context_key = '{}context_embedder.weight'.format(key_prefix)
|
||||
if context_key in state_dict_keys:
|
||||
in_features = state_dict[context_key].shape[1]
|
||||
out_features = state_dict[context_key].shape[0]
|
||||
in_features = sd_shape(state_dict, context_key)[1]
|
||||
out_features = sd_shape(state_dict, context_key)[0]
|
||||
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
|
||||
num_patches_key = '{}pos_embed'.format(key_prefix)
|
||||
if num_patches_key in state_dict_keys:
|
||||
num_patches = state_dict[num_patches_key].shape[1]
|
||||
num_patches = sd_shape(state_dict, num_patches_key)[1]
|
||||
unet_config["num_patches"] = num_patches
|
||||
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
|
||||
|
||||
@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||
if text_mapper_name in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'c'
|
||||
w = state_dict[text_mapper_name]
|
||||
if w.shape[0] == 1536: #stage c lite
|
||||
w_shape = sd_shape(state_dict, text_mapper_name)
|
||||
if w_shape[0] == 1536: #stage c lite
|
||||
unet_config['c_cond'] = 1536
|
||||
unet_config['c_hidden'] = [1536, 1536]
|
||||
unet_config['nhead'] = [24, 24]
|
||||
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||
elif w.shape[0] == 2048: #stage c full
|
||||
elif w_shape[0] == 2048: #stage c full
|
||||
unet_config['c_cond'] = 2048
|
||||
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||
unet_config['stable_cascade_stage'] = 'b'
|
||||
w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
|
||||
if w.shape[-1] == 640:
|
||||
w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix))
|
||||
if w_shape[-1] == 640:
|
||||
unet_config['c_hidden'] = [320, 640, 1280, 1280]
|
||||
unet_config['nhead'] = [-1, -1, 20, 20]
|
||||
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
|
||||
elif w.shape[-1] == 576: #stage b lite
|
||||
elif w_shape[-1] == 576: #stage b lite
|
||||
unet_config['c_hidden'] = [320, 576, 1152, 1152]
|
||||
unet_config['nhead'] = [-1, 9, 18, 18]
|
||||
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
|
||||
@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
|
||||
unet_config = {}
|
||||
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
|
||||
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
|
||||
unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1]
|
||||
unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1]
|
||||
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
|
||||
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
|
||||
unet_config["n_double_layers"] = double_layers
|
||||
@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "hydit"
|
||||
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
|
||||
unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0]
|
||||
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
|
||||
unet_config["mlp_ratio"] = 4.3637
|
||||
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
|
||||
if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968:
|
||||
unet_config["size_cond"] = True
|
||||
unet_config["use_style_cond"] = True
|
||||
unet_config["image_model"] = "hydit1"
|
||||
@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
|
||||
dit_config = {}
|
||||
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
|
||||
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
|
||||
in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix))
|
||||
out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix))
|
||||
dit_config["image_model"] = "hunyuan_video"
|
||||
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = list(in_w.shape[2:])
|
||||
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
|
||||
dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels
|
||||
dit_config["patch_size"] = list(in_w_shape[2:])
|
||||
dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"])
|
||||
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
|
||||
dit_config["vec_in_dim"] = 768
|
||||
else:
|
||||
@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["meanflow"] = False
|
||||
|
||||
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["hidden_size"] = in_w.shape[0]
|
||||
dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1]
|
||||
dit_config["hidden_size"] = in_w_shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = in_w.shape[0] // 128
|
||||
dit_config["num_heads"] = in_w_shape[0] // 128
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["theta"] = 256
|
||||
@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0]
|
||||
dit_config["meanflow_sum"] = True
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
w = state_dict[in_key]
|
||||
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
w_shape = sd_shape(state_dict, in_key)
|
||||
dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size)
|
||||
dit_config["hidden_size"] = w_shape[0]
|
||||
|
||||
txt_in_key = "{}txt_in.weight".format(key_prefix)
|
||||
if txt_in_key in state_dict_keys:
|
||||
w = state_dict[txt_in_key]
|
||||
dit_config["context_in_dim"] = w.shape[1]
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
w_shape = sd_shape(state_dict, txt_in_key)
|
||||
dit_config["context_in_dim"] = w_shape[1]
|
||||
dit_config["hidden_size"] = w_shape[0]
|
||||
|
||||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1]
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
|
||||
shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix))
|
||||
dit_config["attention_head_dim"] = shape[0] // 32
|
||||
dit_config["cross_attention_dim"] = shape[1]
|
||||
if metadata is not None and "config" in metadata:
|
||||
@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
y_key = "{}y_embedder.y_embedding".format(key_prefix)
|
||||
if y_key in state_dict_keys:
|
||||
dit_config["model_max_length"] = state_dict[y_key].shape[0]
|
||||
dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0]
|
||||
|
||||
pe_key = "{}pos_embed".format(key_prefix)
|
||||
if pe_key in state_dict_keys:
|
||||
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
|
||||
dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size
|
||||
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
|
||||
|
||||
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
|
||||
@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0]
|
||||
dit_config["block_config"] = "FA-CA-MLP"
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["in_channels"] = 16
|
||||
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
|
||||
dit_config["dim"] = w.shape[0]
|
||||
dit_config["cap_feat_dim"] = w.shape[1]
|
||||
w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix))
|
||||
dit_config["dim"] = w_shape[0]
|
||||
dit_config["cap_feat_dim"] = w_shape[1]
|
||||
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
dit_config["qk_norm"] = True
|
||||
|
||||
@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None: # NewBie
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix)
|
||||
if ctd_key in state_dict_keys: # NewBie
|
||||
dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0]
|
||||
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
|
||||
out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
|
||||
dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1]
|
||||
out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4
|
||||
dit_config["dim"] = dim
|
||||
dit_config["out_dim"] = out_dim
|
||||
dit_config["num_heads"] = dim // 128
|
||||
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["patch_size"] = (1, 2, 2)
|
||||
dit_config["freq_dim"] = 256
|
||||
@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["cross_attn_norm"] = True
|
||||
dit_config["eps"] = 1e-6
|
||||
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1]
|
||||
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "vace"
|
||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1]
|
||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "i2v"
|
||||
else:
|
||||
dit_config["model_type"] = "t2v"
|
||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||
if flf_weight is not None:
|
||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||
flf_key = '{}img_emb.emb_pos'.format(key_prefix)
|
||||
if flf_key in state_dict_keys:
|
||||
dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1]
|
||||
|
||||
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
|
||||
if ref_conv_weight is not None:
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
ref_conv_key = '{}ref_conv.weight'.format(key_prefix)
|
||||
if ref_conv_key in state_dict_keys:
|
||||
dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix))
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2"
|
||||
dit_config["in_channels"] = in_shape[1]
|
||||
dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1]
|
||||
dit_config["hidden_size"] = in_shape[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 16
|
||||
@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan3d2_1"
|
||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||
dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1]
|
||||
dit_config["context_dim"] = 1024
|
||||
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
|
||||
dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0]
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 16
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
|
||||
@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["max_img_w"] = 240
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0]
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["crossattn_emb_channels"] = 1024
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||
@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0]
|
||||
dit_config["model_dim"] = model_dim
|
||||
if model_dim in [4096, 2560]: # pro video and lite image
|
||||
dit_config["axes_dims"] = (32, 48, 48)
|
||||
@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||
elif model_dim == 1792: # lite video
|
||||
dit_config["axes_dims"] = (16, 24, 24)
|
||||
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0]
|
||||
dit_config["image_model"] = "kandinsky5"
|
||||
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0]
|
||||
dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1]
|
||||
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
|
||||
if y_input in state_dict_keys:
|
||||
unet_config["num_classes"] = "sequential"
|
||||
unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
|
||||
unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1]
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0]
|
||||
in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1]
|
||||
|
||||
out_key = '{}out.2.weight'.format(key_prefix)
|
||||
if out_key in state_dict:
|
||||
out_channels = state_dict[out_key].shape[0]
|
||||
out_channels = sd_shape(state_dict, out_key)[0]
|
||||
else:
|
||||
out_channels = 4
|
||||
|
||||
@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
||||
if res_block_prefix in block_keys:
|
||||
last_res_blocks += 1
|
||||
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
|
||||
last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels
|
||||
|
||||
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||
transformer_depth.append(transformer_count)
|
||||
if transformer_count > 0:
|
||||
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
|
||||
match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1]
|
||||
|
||||
attn_res *= 2
|
||||
if attn_blocks == 0:
|
||||
@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
|
||||
match["transformer_depth"] = transformer_depth
|
||||
|
||||
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||
match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0]
|
||||
match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1]
|
||||
match["adm_in_channels"] = None
|
||||
if "class_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
||||
match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1]
|
||||
elif "add_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
elif 'x_embedder.weight' in state_dict: #Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
hidden_size = state_dict["x_embedder.bias"].shape[0]
|
||||
hidden_size = sd_shape(state_dict, "x_embedder.bias")[0]
|
||||
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
|
||||
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||
depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64
|
||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -26,6 +26,7 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
import comfy.disk_weights
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -540,7 +541,12 @@ class LoadedModel:
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
offload_device = None
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
offload_device = torch.device("meta")
|
||||
self.model.detach(unpatch_weights, offload_device=offload_device)
|
||||
if offload_device is not None and offload_device.type == "meta":
|
||||
logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
@ -594,6 +600,11 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
|
||||
logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
if freed_cache < memory_required:
|
||||
evict_ram_to_disk(memory_required - freed_cache)
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -629,6 +640,34 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
|
||||
def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
|
||||
if memory_to_free <= 0:
|
||||
return 0
|
||||
if not comfy.disk_weights.disk_weights_enabled():
|
||||
return 0
|
||||
|
||||
freed = 0
|
||||
can_unload = []
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
loaded_memory = shift_model.model_loaded_memory()
|
||||
if loaded_memory > 0:
|
||||
can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_needed = memory_to_free - freed
|
||||
if memory_needed <= 0:
|
||||
break
|
||||
logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
|
||||
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
|
||||
|
||||
if freed > 0:
|
||||
logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
|
||||
return freed
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
@ -1135,6 +1174,16 @@ if not args.disable_pinned_memory:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
WEIGHTS_RAM_CACHE_BYTES = 0
|
||||
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
|
||||
if args.weights_ram_cache_gb is not None:
|
||||
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
|
||||
comfy.disk_weights.configure(
|
||||
WEIGHTS_RAM_CACHE_BYTES,
|
||||
allow_gds=WEIGHTS_GDS_ENABLED,
|
||||
pin_if_cpu=not args.disable_pinned_memory,
|
||||
)
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
@ -1291,7 +1340,10 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
if hasattr(dev, 'type') and dev.type == "meta":
|
||||
mem_free_total = sys.maxsize
|
||||
mem_free_torch = mem_free_total
|
||||
elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
|
||||
@ -34,6 +34,7 @@ import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy.disk_weights
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
@ -269,6 +270,8 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
comfy.disk_weights.attach_disk_weight_hooks(self.model)
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@ -783,7 +786,7 @@ class ModelPatcher:
|
||||
m.comfy_patched_weights = True
|
||||
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
comfy.disk_weights.module_to(x[2], device_to)
|
||||
|
||||
for x in offloaded:
|
||||
n = x[1]
|
||||
@ -799,7 +802,7 @@ class ModelPatcher:
|
||||
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
comfy.disk_weights.module_to(self.model, device_to)
|
||||
mem_counter = self.model_size()
|
||||
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
@ -856,7 +859,7 @@ class ModelPatcher:
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
@ -883,6 +886,9 @@ class ModelPatcher:
|
||||
if len(unload_list) > 0:
|
||||
NS = comfy.model_management.NUM_STREAMS
|
||||
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
|
||||
remaining_ram = None
|
||||
if device_to is not None and comfy.model_management.is_device_cpu(device_to):
|
||||
remaining_ram = comfy.model_management.get_free_memory(device_to)
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
@ -916,7 +922,24 @@ class ModelPatcher:
|
||||
bias_key = "{}.bias".format(n)
|
||||
if move_weight:
|
||||
cast_weight = self.force_cast_weights
|
||||
m.to(device_to)
|
||||
freed_bytes = module_mem
|
||||
if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
|
||||
freed_bytes = comfy.disk_weights.offload_module_weights(m)
|
||||
if freed_bytes == 0:
|
||||
freed_bytes = module_mem
|
||||
else:
|
||||
if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
|
||||
logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
|
||||
freed_bytes = comfy.disk_weights.offload_module_weights(m)
|
||||
if freed_bytes == 0:
|
||||
freed_bytes = module_mem
|
||||
else:
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
comfy.disk_weights.move_module_tensors(m, device_to)
|
||||
else:
|
||||
m.to(device_to)
|
||||
if remaining_ram is not None:
|
||||
remaining_ram = max(0, remaining_ram - module_mem)
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
@ -939,7 +962,7 @@ class ModelPatcher:
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
memory_freed += freed_bytes
|
||||
offload_buffer = max(offload_buffer, potential_offload)
|
||||
offload_weight_factor.append(module_mem)
|
||||
offload_weight_factor.pop(0)
|
||||
@ -953,7 +976,8 @@ class ModelPatcher:
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
|
||||
logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
@ -984,11 +1008,12 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
def detach(self, unpatch_all=True, offload_device=None):
|
||||
self.eject_model()
|
||||
self.model_patches_to(self.offload_device)
|
||||
target_device = self.offload_device if offload_device is None else offload_device
|
||||
if unpatch_all:
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
||||
self.unpatch_model(target_device, unpatch_weights=unpatch_all)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
|
||||
callback(self, unpatch_all)
|
||||
return self.model
|
||||
@ -1358,4 +1383,3 @@ class ModelPatcher:
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
|
||||
|
||||
|
||||
63
comfy/ops.py
63
comfy/ops.py
@ -19,6 +19,7 @@
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.disk_weights
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
@ -98,11 +99,35 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
weight_source = s.weight
|
||||
bias_source = s.bias
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
if weight_source.device.type == "meta":
|
||||
loaded = comfy.disk_weights.load_module_tensor(
|
||||
s,
|
||||
"weight",
|
||||
device,
|
||||
temporary=True,
|
||||
dtype_override=dtype,
|
||||
)
|
||||
if loaded is not None:
|
||||
weight_source = loaded
|
||||
if bias_source is not None and bias_source.device.type == "meta":
|
||||
loaded_bias = comfy.disk_weights.load_module_tensor(
|
||||
s,
|
||||
"bias",
|
||||
device,
|
||||
temporary=True,
|
||||
dtype_override=bias_dtype,
|
||||
)
|
||||
if loaded_bias is not None:
|
||||
bias_source = loaded_bias
|
||||
|
||||
weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
if bias_source is not None:
|
||||
bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
@ -532,9 +557,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
if value.device.type != "meta":
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
@ -551,11 +577,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
if layer_conf is not None and layer_conf.device.type != "meta":
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
elif layer_conf is not None:
|
||||
layer_conf = None
|
||||
|
||||
if layer_conf is None:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
if weight.device.type == "meta":
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
@ -601,10 +632,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
if weight.device.type == "meta":
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
@ -614,7 +648,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
if _v.device.type == "meta":
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
|
||||
else:
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
934
comfy/safetensors_stream.py
Normal file
934
comfy/safetensors_stream.py
Normal file
@ -0,0 +1,934 @@
|
||||
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import ctypes
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_FST_MODULE = None
|
||||
_FST_LOCK = threading.Lock()
|
||||
_FST_LOADED = False
|
||||
_GDS_INITIALIZED = False
|
||||
_MISSING = object()
|
||||
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
|
||||
|
||||
|
||||
def _require_fastsafetensors():
|
||||
global _FST_MODULE
|
||||
with _FST_LOCK:
|
||||
if _FST_MODULE is None:
|
||||
if importlib.util.find_spec("fastsafetensors") is None:
|
||||
raise ImportError(
|
||||
"fastsafetensors is required for safetensors streaming. "
|
||||
"Install it with: pip install 'fastsafetensors @ https://github.com/"
|
||||
"foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
|
||||
)
|
||||
_FST_MODULE = importlib.import_module("fastsafetensors")
|
||||
return _FST_MODULE
|
||||
|
||||
|
||||
def _init_fastsafetensors_lib():
|
||||
global _FST_LOADED
|
||||
fst = _require_fastsafetensors()
|
||||
if not _FST_LOADED:
|
||||
fst.cpp.load_library_functions()
|
||||
_FST_LOADED = True
|
||||
return fst
|
||||
|
||||
|
||||
def _init_gds():
|
||||
global _GDS_INITIALIZED
|
||||
fst = _init_fastsafetensors_lib()
|
||||
if not _GDS_INITIALIZED:
|
||||
if fst.cpp.init_gds() != 0:
|
||||
raise RuntimeError("fastsafetensors init_gds() failed")
|
||||
_GDS_INITIALIZED = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorMeta:
|
||||
dtype: torch.dtype
|
||||
shape: Tuple[int, ...]
|
||||
numel: int
|
||||
nbytes: int
|
||||
data_offsets: Tuple[int, int]
|
||||
filename: str
|
||||
fst_dtype: object
|
||||
strides: Tuple[int, ...]
|
||||
|
||||
|
||||
class SafeTensorIndex:
|
||||
def __init__(self, filename: str):
|
||||
fst = _init_fastsafetensors_lib()
|
||||
framework = fst.frameworks.get_framework_op("pytorch")
|
||||
metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
|
||||
self._filename = filename
|
||||
self._metadata = metadata
|
||||
self._framework = framework
|
||||
from fastsafetensors.frameworks import _torch as fst_torch
|
||||
self._dtype_map = fst_torch.dtype_convert
|
||||
self._tensor_meta: Dict[str, TensorMeta] = {}
|
||||
for key, frame in metadata.tensors.items():
|
||||
torch_dtype = self._dtype_map.get(frame.dtype, None)
|
||||
if torch_dtype is None:
|
||||
raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
|
||||
numel = 1
|
||||
for s in frame.shape:
|
||||
numel *= s
|
||||
nbytes = numel * framework.get_dtype_size(frame.dtype)
|
||||
self._tensor_meta[key] = TensorMeta(
|
||||
dtype=torch_dtype,
|
||||
shape=tuple(frame.shape),
|
||||
numel=numel,
|
||||
nbytes=nbytes,
|
||||
data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
|
||||
filename=filename,
|
||||
fst_dtype=frame.dtype,
|
||||
strides=tuple(frame.strides),
|
||||
)
|
||||
|
||||
def keys(self) -> Iterable[str]:
|
||||
return self._tensor_meta.keys()
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
return key in self._tensor_meta
|
||||
|
||||
def meta(self, key: str) -> TensorMeta:
|
||||
return self._tensor_meta[key]
|
||||
|
||||
def metadata(self):
|
||||
return self._metadata.metadata
|
||||
|
||||
@property
|
||||
def header_length(self) -> int:
|
||||
return self._metadata.header_length
|
||||
|
||||
@property
|
||||
def size_bytes(self) -> int:
|
||||
return self._metadata.size_bytes
|
||||
|
||||
|
||||
class _SafeTensorFile:
|
||||
def __init__(self, filename: str, index: SafeTensorIndex):
|
||||
self.filename = filename
|
||||
self.index = index
|
||||
self._fd: Optional[int] = None
|
||||
self._gds_handle = None
|
||||
self._gds_reader = None
|
||||
self._nogds_reader = None
|
||||
self._refcount = 1
|
||||
|
||||
def acquire(self) -> "_SafeTensorFile":
|
||||
self._refcount += 1
|
||||
return self
|
||||
|
||||
def release(self):
|
||||
self._refcount -= 1
|
||||
if self._refcount <= 0:
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
if self._fd is not None:
|
||||
os.close(self._fd)
|
||||
self._fd = None
|
||||
self._gds_handle = None
|
||||
|
||||
def _ensure_fd(self) -> int:
|
||||
if self._fd is None:
|
||||
self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
|
||||
return self._fd
|
||||
|
||||
def _ensure_nogds_reader(self, use_cuda: bool):
|
||||
fst = _init_fastsafetensors_lib()
|
||||
if self._nogds_reader is None:
|
||||
self._nogds_reader = fst.cpp.nogds_file_reader(
|
||||
False, 16 * 1024, 16, use_cuda
|
||||
)
|
||||
return self._nogds_reader
|
||||
|
||||
def _ensure_gds_reader(self, use_cuda: bool):
|
||||
fst = _init_fastsafetensors_lib()
|
||||
if self._gds_reader is None:
|
||||
self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
|
||||
return self._gds_reader
|
||||
|
||||
def _ensure_gds_handle(self, use_cuda: bool):
|
||||
if self._gds_handle is None:
|
||||
fst = _init_fastsafetensors_lib()
|
||||
framework = fst.frameworks.get_framework_op("pytorch")
|
||||
o_direct = _get_gds_o_direct(framework)
|
||||
self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
|
||||
return self._gds_handle
|
||||
|
||||
def read_tensor(
|
||||
self,
|
||||
meta: TensorMeta,
|
||||
device: torch.device,
|
||||
dtype: Optional[torch.dtype],
|
||||
allow_gds: bool,
|
||||
pin_if_cpu: bool,
|
||||
) -> torch.Tensor:
|
||||
fst = _init_fastsafetensors_lib()
|
||||
framework = fst.frameworks.get_framework_op("pytorch")
|
||||
device_is_cuda = device.type == "cuda"
|
||||
if device_is_cuda and allow_gds:
|
||||
_ensure_gds_ready(device)
|
||||
tensor = self._read_tensor_gds(
|
||||
fst, framework, meta, device, dtype
|
||||
)
|
||||
return tensor
|
||||
|
||||
cpu_tensor = self._read_tensor_nogds(
|
||||
fst, framework, meta, torch.device("cpu"), dtype
|
||||
)
|
||||
if device_is_cuda:
|
||||
if pin_if_cpu:
|
||||
cpu_tensor = cpu_tensor.pin_memory()
|
||||
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
|
||||
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
|
||||
return gpu_tensor
|
||||
return cpu_tensor
|
||||
|
||||
def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
|
||||
fst = _init_fastsafetensors_lib()
|
||||
align = fst.cpp.get_alignment_size()
|
||||
aligned_offset = (abs_start // align) * align
|
||||
head = abs_start - aligned_offset
|
||||
aligned_length = length + head
|
||||
tail = aligned_length % align
|
||||
if tail:
|
||||
aligned_length += align - tail
|
||||
return aligned_offset, aligned_length, head
|
||||
|
||||
def _read_tensor_nogds(
|
||||
self,
|
||||
fst,
|
||||
framework,
|
||||
meta: TensorMeta,
|
||||
device: torch.device,
|
||||
dtype: Optional[torch.dtype],
|
||||
) -> torch.Tensor:
|
||||
fd = self._ensure_fd()
|
||||
reader = self._ensure_nogds_reader(use_cuda=False)
|
||||
abs_start = self.index.header_length + meta.data_offsets[0]
|
||||
length = meta.data_offsets[1] - meta.data_offsets[0]
|
||||
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
|
||||
chunk_bytes = max(1, chunk_bytes)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
|
||||
buffer_length = 0
|
||||
buf_ptr = None
|
||||
gbuf = None
|
||||
try:
|
||||
chunk_offset = 0
|
||||
while chunk_offset < length:
|
||||
chunk_len = min(length - chunk_offset, chunk_bytes)
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
|
||||
needed = aligned_length + ptr_align
|
||||
if buf_ptr is None or needed > buffer_length:
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
buffer_length = needed
|
||||
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
|
||||
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
|
||||
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
|
||||
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
|
||||
if reader.wait_read(req) < 0:
|
||||
raise RuntimeError("nogds_file_reader read failed")
|
||||
src_ptr = gbuf.get_base_address() + ptr_off + head
|
||||
dest_ptr = dest_tensor.data_ptr() + chunk_offset
|
||||
ctypes.memmove(dest_ptr, src_ptr, chunk_len)
|
||||
chunk_offset += chunk_len
|
||||
except Exception:
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
raise
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
if dtype is not None and dtype != dest_tensor.dtype:
|
||||
_validate_dtype_conversion(dest_tensor.dtype, dtype)
|
||||
dest_tensor = dest_tensor.to(dtype=dtype)
|
||||
return dest_tensor
|
||||
|
||||
def _read_tensor_gds(
|
||||
self,
|
||||
fst,
|
||||
framework,
|
||||
meta: TensorMeta,
|
||||
device: torch.device,
|
||||
dtype: Optional[torch.dtype],
|
||||
) -> torch.Tensor:
|
||||
reader = self._ensure_gds_reader(use_cuda=True)
|
||||
handle = self._ensure_gds_handle(use_cuda=True)
|
||||
abs_start = self.index.header_length + meta.data_offsets[0]
|
||||
length = meta.data_offsets[1] - meta.data_offsets[0]
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
buffer_length = aligned_length + ptr_align
|
||||
fst_device = _fst_device_from_torch(fst, device)
|
||||
gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
|
||||
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
|
||||
file_length = self.index.size_bytes
|
||||
req = reader.submit_read(
|
||||
handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
|
||||
)
|
||||
if reader.wait_read(req) < 0:
|
||||
framework.free_tensor_memory(gbuf, fst_device)
|
||||
raise RuntimeError("gds_file_reader read failed")
|
||||
owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
|
||||
tensor = _dlpack_tensor_from_buffer(
|
||||
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
|
||||
)
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
_validate_dtype_conversion(tensor.dtype, dtype)
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def _fst_device_from_torch(fst, device: torch.device):
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
return fst.st_types.Device.from_str(f"cuda:{device.index}")
|
||||
return fst.st_types.Device.from_str(device.type)
|
||||
|
||||
|
||||
class _BufferOwner:
|
||||
def __init__(self, free_fn):
|
||||
self._free_fn = free_fn
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self._free_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _dlpack_tensor_from_buffer(
|
||||
fst,
|
||||
framework,
|
||||
ptr: int,
|
||||
meta: TensorMeta,
|
||||
device: torch.device,
|
||||
owner: Optional[_BufferOwner],
|
||||
) -> torch.Tensor:
|
||||
disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
|
||||
dev = _fst_device_from_torch(fst, device)
|
||||
dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
|
||||
torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
|
||||
if disk_dtype != meta.fst_dtype:
|
||||
torch_tensor = torch_tensor.view(meta.dtype)
|
||||
if owner is not None:
|
||||
torch_tensor._comfy_disk_buffer_owner = owner
|
||||
return torch_tensor
|
||||
|
||||
|
||||
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
|
||||
if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
|
||||
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
|
||||
|
||||
|
||||
def _get_gds_o_direct(framework) -> bool:
|
||||
cuda_ver = framework.get_cuda_ver()
|
||||
if cuda_ver and cuda_ver != "0.0":
|
||||
ver_parts = cuda_ver.split("-", 1)
|
||||
if len(ver_parts) == 2:
|
||||
cudavers = list(map(int, ver_parts[1].split(".")))
|
||||
if ver_parts[0] == "cuda":
|
||||
return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def _ensure_gds_ready(device: torch.device):
|
||||
fst = _init_fastsafetensors_lib()
|
||||
if not fst.common.is_gpu_found():
|
||||
raise RuntimeError(
|
||||
"GPUDirect requested but GPU runtime library is missing. "
|
||||
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
|
||||
)
|
||||
gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
|
||||
if gds_supported < 0:
|
||||
raise RuntimeError(
|
||||
"GPUDirect requested but is_gds_supported() failed. "
|
||||
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
|
||||
)
|
||||
if not fst.cpp.is_cufile_found():
|
||||
raise RuntimeError(
|
||||
"GPUDirect requested but libcufile is missing. "
|
||||
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
|
||||
)
|
||||
if gds_supported == 0:
|
||||
raise RuntimeError(
|
||||
"GPUDirect requested but GDS is unsupported on this platform. "
|
||||
"Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
|
||||
)
|
||||
_init_gds()
|
||||
|
||||
|
||||
class StreamStateDict(collections.abc.MutableMapping):
|
||||
is_stream_state_dict = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: SafeTensorIndex,
|
||||
file: _SafeTensorFile,
|
||||
device: torch.device,
|
||||
allow_gds: bool = False,
|
||||
):
|
||||
self._index = index
|
||||
self._file = file
|
||||
self._device = device
|
||||
self._allow_gds = allow_gds
|
||||
self._overrides: Dict[str, torch.Tensor] = {}
|
||||
self._deleted: set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
|
||||
index = SafeTensorIndex(filename)
|
||||
file = _SafeTensorFile(filename, index)
|
||||
return cls(index, file, device, allow_gds=allow_gds)
|
||||
|
||||
def close(self):
|
||||
if self._file is not None:
|
||||
self._file.release()
|
||||
self._file = None
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def meta(self, key: str) -> TensorMeta:
|
||||
if key in self._overrides:
|
||||
t = self._overrides[key]
|
||||
numel = t.numel()
|
||||
return TensorMeta(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
data_offsets=(0, numel * t.element_size()),
|
||||
filename="<override>",
|
||||
fst_dtype=None,
|
||||
strides=tuple(t.stride()),
|
||||
)
|
||||
if key in self._deleted:
|
||||
raise KeyError(key)
|
||||
return self._index.meta(key)
|
||||
|
||||
def get_tensor(
|
||||
self,
|
||||
key: str,
|
||||
*,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
allow_gds: Optional[bool] = None,
|
||||
pin_if_cpu: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if key in self._overrides:
|
||||
t = self._overrides[key]
|
||||
if device is not None and t.device != device:
|
||||
t = t.to(device=device)
|
||||
if dtype is not None and t.dtype != dtype:
|
||||
_validate_dtype_conversion(t.dtype, dtype)
|
||||
t = t.to(dtype=dtype)
|
||||
return t
|
||||
if key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if device is None:
|
||||
device = self._device
|
||||
if device.type == "meta":
|
||||
meta = self._index.meta(key)
|
||||
target_dtype = dtype or meta.dtype
|
||||
if dtype is not None and dtype != meta.dtype:
|
||||
_validate_dtype_conversion(meta.dtype, dtype)
|
||||
return torch.empty(meta.shape, dtype=target_dtype, device="meta")
|
||||
if allow_gds is None:
|
||||
allow_gds = self._allow_gds
|
||||
meta = self._index.meta(key)
|
||||
return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
def __getitem__(self, key: str) -> torch.Tensor:
|
||||
return self.get_tensor(key)
|
||||
|
||||
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
||||
self._overrides[key] = value
|
||||
self._deleted.discard(key)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
return
|
||||
if key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if self._index.has(key):
|
||||
self._deleted.add(key)
|
||||
return
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for k in self._index.keys():
|
||||
if k in self._deleted:
|
||||
continue
|
||||
if k in self._overrides:
|
||||
continue
|
||||
yield k
|
||||
for k in self._overrides.keys():
|
||||
yield k
|
||||
|
||||
def __len__(self) -> int:
|
||||
base = len(self._index.keys())
|
||||
return base - len(self._deleted) + len(self._overrides)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
if not isinstance(key, str):
|
||||
return False
|
||||
if key in self._deleted:
|
||||
return False
|
||||
if key in self._overrides:
|
||||
return True
|
||||
return self._index.has(key)
|
||||
|
||||
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
|
||||
if key in self._overrides:
|
||||
return self._overrides.pop(key)
|
||||
if key in self._deleted:
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
if self._index.has(key):
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
|
||||
def copy(self) -> "StreamStateDict":
|
||||
new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
|
||||
new._overrides = dict(self._overrides)
|
||||
new._deleted = set(self._deleted)
|
||||
return new
|
||||
|
||||
def metadata(self):
|
||||
return self._index.metadata()
|
||||
|
||||
|
||||
class _BaseViewStateDict(MutableMapping):
|
||||
is_stream_state_dict = True
|
||||
|
||||
def __init__(self, base: MutableMapping, mutate_base: bool = False):
|
||||
self._base = base
|
||||
self._mutate_base = mutate_base
|
||||
self._overrides: Dict[str, torch.Tensor] = {}
|
||||
self._deleted: set[str] = set()
|
||||
|
||||
def _resolve_base_key(self, key: str) -> Optional[str]:
|
||||
return key
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
return self._base.keys()
|
||||
|
||||
def get_tensor(
|
||||
self,
|
||||
key: str,
|
||||
*,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
allow_gds: Optional[bool] = None,
|
||||
pin_if_cpu: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if key in self._overrides:
|
||||
t = self._overrides[key]
|
||||
if device is not None and t.device != device:
|
||||
t = t.to(device=device)
|
||||
if dtype is not None and t.dtype != dtype:
|
||||
_validate_dtype_conversion(t.dtype, dtype)
|
||||
t = t.to(dtype=dtype)
|
||||
return t
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if hasattr(self._base, "get_tensor"):
|
||||
return self._base.get_tensor(
|
||||
base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
|
||||
)
|
||||
t = self._base[base_key]
|
||||
if device is not None and t.device != device:
|
||||
t = t.to(device=device)
|
||||
if dtype is not None and t.dtype != dtype:
|
||||
_validate_dtype_conversion(t.dtype, dtype)
|
||||
t = t.to(dtype=dtype)
|
||||
return t
|
||||
|
||||
def __getitem__(self, key: str) -> torch.Tensor:
|
||||
return self.get_tensor(key)
|
||||
|
||||
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
||||
base_key = self._resolve_base_key(key)
|
||||
if self._mutate_base and base_key is not None and base_key in self._base:
|
||||
self._base[base_key] = value
|
||||
else:
|
||||
self._overrides[key] = value
|
||||
self._deleted.discard(key)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
return
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if self._mutate_base and base_key in self._base:
|
||||
del self._base[base_key]
|
||||
else:
|
||||
self._deleted.add(key)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for k in self._iter_base_keys():
|
||||
if k in self._deleted:
|
||||
continue
|
||||
yield k
|
||||
for k in self._overrides.keys():
|
||||
yield k
|
||||
|
||||
def __len__(self) -> int:
|
||||
base_keys = list(self._iter_base_keys())
|
||||
return len(base_keys) - len(self._deleted) + len(self._overrides)
|
||||
|
||||
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
|
||||
if key in self._overrides:
|
||||
return self._overrides.pop(key)
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
if self._mutate_base:
|
||||
try:
|
||||
return self._base.pop(base_key)
|
||||
except KeyError:
|
||||
if default is _MISSING:
|
||||
raise
|
||||
return default
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
|
||||
def meta(self, key: str):
|
||||
if key in self._overrides:
|
||||
t = self._overrides[key]
|
||||
numel = t.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
)
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if hasattr(self._base, "meta"):
|
||||
return self._base.meta(base_key)
|
||||
t = self._base[base_key]
|
||||
numel = t.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
)
|
||||
|
||||
|
||||
class DeviceViewStateDict(_BaseViewStateDict):
|
||||
def __init__(
|
||||
self,
|
||||
base: MutableMapping,
|
||||
device: torch.device,
|
||||
allow_gds: Optional[bool] = None,
|
||||
pin_if_cpu: bool = False,
|
||||
mutate_base: bool = False,
|
||||
):
|
||||
super().__init__(base, mutate_base=mutate_base)
|
||||
self._device = device
|
||||
self._allow_gds = allow_gds
|
||||
self._pin_if_cpu = pin_if_cpu
|
||||
|
||||
def get_tensor(
|
||||
self,
|
||||
key: str,
|
||||
*,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
allow_gds: Optional[bool] = None,
|
||||
pin_if_cpu: bool = False,
|
||||
) -> torch.Tensor:
|
||||
device = self._device if device is None else device
|
||||
allow_gds = self._allow_gds if allow_gds is None else allow_gds
|
||||
pin_if_cpu = self._pin_if_cpu if not pin_if_cpu else pin_if_cpu
|
||||
return super().get_tensor(
|
||||
key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
|
||||
)
|
||||
|
||||
def meta(self, key: str):
|
||||
if key in self._overrides:
|
||||
t = self._overrides[key]
|
||||
numel = t.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
)
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if hasattr(self._base, "meta"):
|
||||
return self._base.meta(base_key)
|
||||
t = self._base[base_key]
|
||||
numel = t.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
)
|
||||
|
||||
def __getitem__(self, key: str) -> torch.Tensor:
|
||||
return self.get_tensor(key)
|
||||
|
||||
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
||||
base_key = self._resolve_base_key(key)
|
||||
if self._mutate_base and base_key is not None and base_key in self._base:
|
||||
self._base[base_key] = value
|
||||
else:
|
||||
self._overrides[key] = value
|
||||
self._deleted.discard(key)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
return
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if self._mutate_base and base_key in self._base:
|
||||
del self._base[base_key]
|
||||
else:
|
||||
self._deleted.add(key)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for k in self._iter_base_keys():
|
||||
if k in self._deleted:
|
||||
continue
|
||||
yield k
|
||||
for k in self._overrides.keys():
|
||||
yield k
|
||||
|
||||
def __len__(self) -> int:
|
||||
base_keys = list(self._iter_base_keys())
|
||||
return len(base_keys) - len(self._deleted) + len(self._overrides)
|
||||
|
||||
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
|
||||
if key in self._overrides:
|
||||
return self._overrides.pop(key)
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
if self._mutate_base:
|
||||
try:
|
||||
return self._base.pop(base_key)
|
||||
except KeyError:
|
||||
if default is _MISSING:
|
||||
raise
|
||||
return default
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
|
||||
|
||||
class FilterViewStateDict(_BaseViewStateDict):
|
||||
def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
|
||||
super().__init__(base, mutate_base=mutate_base)
|
||||
self._predicate = predicate
|
||||
|
||||
def _resolve_base_key(self, key: str) -> Optional[str]:
|
||||
if self._predicate(key):
|
||||
return key
|
||||
return None
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
for k in self._base.keys():
|
||||
if self._predicate(k):
|
||||
yield k
|
||||
|
||||
|
||||
class PrefixViewStateDict(_BaseViewStateDict):
|
||||
def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
|
||||
super().__init__(base, mutate_base=mutate_base)
|
||||
self._source_prefix = source_prefix
|
||||
self._target_prefix = target_prefix
|
||||
self._mapping: Dict[str, str] = {}
|
||||
self._reverse: Dict[str, str] = {}
|
||||
for k in base.keys():
|
||||
if not k.startswith(source_prefix):
|
||||
continue
|
||||
view_key = f"{target_prefix}{k[len(source_prefix):]}"
|
||||
self._mapping[k] = view_key
|
||||
self._reverse[view_key] = k
|
||||
|
||||
def _resolve_base_key(self, key: str) -> Optional[str]:
|
||||
return self._reverse.get(key)
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
return self._reverse.keys()
|
||||
|
||||
|
||||
class RenameViewStateDict(_BaseViewStateDict):
|
||||
def __init__(
|
||||
self,
|
||||
base: MutableMapping,
|
||||
replace_prefix: Mapping[str, str],
|
||||
filter_keys: bool = False,
|
||||
mutate_base: bool = False,
|
||||
):
|
||||
super().__init__(base, mutate_base=mutate_base)
|
||||
self._filter_keys = filter_keys
|
||||
self._replace = list(replace_prefix.items())
|
||||
self._mapping: Dict[str, str] = {}
|
||||
self._reverse: Dict[str, str] = {}
|
||||
for k in base.keys():
|
||||
view_key = self._replace_key(k)
|
||||
if view_key is None:
|
||||
continue
|
||||
self._mapping[k] = view_key
|
||||
self._reverse[view_key] = k
|
||||
|
||||
def _replace_key(self, key: str) -> Optional[str]:
|
||||
for rp, dst in self._replace:
|
||||
if key.startswith(rp):
|
||||
return f"{dst}{key[len(rp):]}"
|
||||
if self._filter_keys:
|
||||
return None
|
||||
return key
|
||||
|
||||
def _resolve_base_key(self, key: str) -> Optional[str]:
|
||||
return self._reverse.get(key)
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
return self._reverse.keys()
|
||||
|
||||
|
||||
class MergedStateDict(MutableMapping):
|
||||
is_stream_state_dict = True
|
||||
|
||||
def __init__(self, *mappings: MutableMapping):
|
||||
self._mappings = list(mappings)
|
||||
self._overrides: Dict[str, torch.Tensor] = {}
|
||||
self._deleted: set[str] = set()
|
||||
|
||||
def __getitem__(self, key: str) -> torch.Tensor:
|
||||
if key in self._overrides:
|
||||
return self._overrides[key]
|
||||
if key in self._deleted:
|
||||
raise KeyError(key)
|
||||
for mapping in reversed(self._mappings):
|
||||
if key in mapping:
|
||||
if hasattr(mapping, "get_tensor"):
|
||||
return mapping.get_tensor(key)
|
||||
return mapping[key]
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
||||
self._overrides[key] = value
|
||||
self._deleted.discard(key)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
return
|
||||
if key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if any(key in mapping for mapping in self._mappings):
|
||||
self._deleted.add(key)
|
||||
return
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
seen = set()
|
||||
for mapping in self._mappings:
|
||||
for key in mapping.keys():
|
||||
if key in self._deleted or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
yield key
|
||||
for key in self._overrides.keys():
|
||||
if key not in seen:
|
||||
yield key
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(list(self.__iter__()))
|
||||
|
||||
def meta(self, key: str):
|
||||
if key in self._overrides:
|
||||
t = self._overrides[key]
|
||||
numel = t.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
)
|
||||
if key in self._deleted:
|
||||
raise KeyError(key)
|
||||
for mapping in reversed(self._mappings):
|
||||
if key in mapping:
|
||||
if hasattr(mapping, "meta"):
|
||||
return mapping.meta(key)
|
||||
t = mapping[key]
|
||||
numel = t.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=t.dtype,
|
||||
shape=tuple(t.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * t.element_size(),
|
||||
)
|
||||
raise KeyError(key)
|
||||
|
||||
|
||||
class MappedStateDict(_BaseViewStateDict):
|
||||
def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
|
||||
super().__init__(base, mutate_base=mutate_base)
|
||||
self._base_to_view = dict(key_map)
|
||||
self._view_to_base = {v: k for k, v in key_map.items()}
|
||||
|
||||
def _resolve_base_key(self, key: str) -> Optional[str]:
|
||||
return self._view_to_base.get(key)
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
return self._view_to_base.keys()
|
||||
136
comfy/sd.py
136
comfy/sd.py
@ -25,6 +25,8 @@ import math
|
||||
import os
|
||||
|
||||
import comfy.utils
|
||||
import comfy.safetensors_stream
|
||||
import comfy.disk_weights
|
||||
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
@ -124,7 +126,7 @@ class CLIP:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
if params['device'] != offload_device:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
@ -288,7 +290,7 @@ class CLIP:
|
||||
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
|
||||
else:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
@ -349,7 +351,7 @@ class VAE:
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||
self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1]
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
@ -364,25 +366,19 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["encoder.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
|
||||
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["previewer.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."})
|
||||
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
|
||||
self.downscale_ratio = 32
|
||||
self.upscale_ratio = 32
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
@ -392,9 +388,9 @@ class VAE:
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
|
||||
elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
@ -417,7 +413,7 @@ class VAE:
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
|
||||
if 'decoder.post_quant_conv.weight' in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
|
||||
|
||||
@ -430,7 +426,7 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
|
||||
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
|
||||
else:
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
@ -465,11 +461,11 @@ class VAE:
|
||||
self.downscale_index_formula = (6, 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
|
||||
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
|
||||
tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight")
|
||||
version = 0
|
||||
if tensor_conv1.shape[0] == 512:
|
||||
if tensor_conv1_shape[0] == 512:
|
||||
version = 0
|
||||
elif tensor_conv1.shape[0] == 1024:
|
||||
elif tensor_conv1_shape[0] == 1024:
|
||||
version = 1
|
||||
if "encoder.down_blocks.1.conv.conv.bias" in sd:
|
||||
version = 2
|
||||
@ -486,9 +482,9 @@ class VAE:
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
self.downscale_index_formula = (8, 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
|
||||
self.latent_channels = 32
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
@ -512,8 +508,8 @@ class VAE:
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
|
||||
#This is likely to significantly over-estimate with single image or low frame counts as the
|
||||
#implementation is able to completely skip caching. Rework if used as an image only VAE
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
@ -546,14 +542,14 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
|
||||
else: # Wan 2.1 VAE
|
||||
dim = sd["decoder.head.0.gamma"].shape[0]
|
||||
dim = sd_shape(sd, "decoder.head.0.gamma")[0]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||
self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1]
|
||||
self.pad_channel_value = 1.0
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
@ -629,7 +625,7 @@ class VAE:
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
|
||||
self.latent_channels = sd["decoder.1.weight"].shape[1]
|
||||
self.latent_channels = sd_shape(sd, "decoder.1.weight")[1]
|
||||
self.latent_dim = 3
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
@ -640,12 +636,12 @@ class VAE:
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
latent_format=comfy.latent_formats.HunyuanVideo
|
||||
else:
|
||||
latent_format=None # lighttaew2_1 doesn't need scaling
|
||||
@ -665,7 +661,7 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
@ -679,7 +675,7 @@ class VAE:
|
||||
if dtype is None:
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
@ -986,9 +982,12 @@ def load_style_model(ckpt_path):
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
model.load_state_dict(model_data)
|
||||
comfy.utils.load_state_dict(model, model_data, strict=True)
|
||||
return StyleModel(model)
|
||||
|
||||
def sd_shape(state_dict, key):
|
||||
return comfy.utils.state_dict_meta(state_dict, key).shape
|
||||
|
||||
class CLIPType(Enum):
|
||||
STABLE_DIFFUSION = 1
|
||||
STABLE_CASCADE = 2
|
||||
@ -1058,16 +1057,16 @@ def detect_te_model(sd):
|
||||
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[-1] == 4096:
|
||||
weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight")
|
||||
if weight_shape[-1] == 4096:
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[-1] == 2048:
|
||||
elif weight_shape[-1] == 2048:
|
||||
return TEModel.T5_XL
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
|
||||
if weight.shape[0] == 384:
|
||||
weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight')
|
||||
if weight_shape[0] == 384:
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
@ -1077,19 +1076,19 @@ def detect_te_model(sd):
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias')
|
||||
if weight_shape[0] == 256:
|
||||
return TEModel.QWEN25_3B
|
||||
if weight.shape[0] == 512:
|
||||
if weight_shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight')
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
if weight.shape[0] == 2560:
|
||||
if weight_shape[0] == 2560:
|
||||
return TEModel.QWEN3_4B
|
||||
elif weight.shape[0] == 2048:
|
||||
elif weight_shape[0] == 2048:
|
||||
return TEModel.QWEN3_2B
|
||||
if weight.shape[0] == 5120:
|
||||
if weight_shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
else:
|
||||
@ -1418,19 +1417,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
scaled_fp8_list.append(k[:-len("scaled_fp8")])
|
||||
|
||||
if len(scaled_fp8_list) > 0:
|
||||
out_sd = {}
|
||||
for k in sd:
|
||||
skip = False
|
||||
if comfy.utils.is_stream_state_dict(sd):
|
||||
def _keep_key(k, prefixes=tuple(scaled_fp8_list)):
|
||||
return not any(k.startswith(pref) for pref in prefixes)
|
||||
out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
|
||||
merged = out_sd
|
||||
for pref in scaled_fp8_list:
|
||||
skip = skip or k.startswith(pref)
|
||||
if not skip:
|
||||
out_sd[k] = sd[k]
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd)
|
||||
sd = merged
|
||||
else:
|
||||
out_sd = {}
|
||||
for k in sd:
|
||||
skip = False
|
||||
for pref in scaled_fp8_list:
|
||||
skip = skip or k.startswith(pref)
|
||||
if not skip:
|
||||
out_sd[k] = sd[k]
|
||||
|
||||
for pref in scaled_fp8_list:
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
for k in quant_sd:
|
||||
out_sd[k] = quant_sd[k]
|
||||
sd = out_sd
|
||||
for pref in scaled_fp8_list:
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
for k in quant_sd:
|
||||
out_sd[k] = quant_sd[k]
|
||||
sd = out_sd
|
||||
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
@ -1508,12 +1517,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
||||
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in sd:
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
if comfy.utils.is_stream_state_dict(sd):
|
||||
new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys)
|
||||
else:
|
||||
new_sd = {}
|
||||
for k in diffusers_keys:
|
||||
if k in sd:
|
||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
@ -1538,7 +1550,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model = comfy.disk_weights.module_to(model, offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
|
||||
@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
|
||||
try:
|
||||
if embed_path.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||
embed = comfy.utils.load_torch_file(embed_path, safe_load=True)
|
||||
else:
|
||||
try:
|
||||
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||
|
||||
@ -56,9 +56,9 @@ class TAESD(nn.Module):
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||
if encoder_path is not None:
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
|
||||
if decoder_path is not None:
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
|
||||
@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
return comfy.utils.load_state_dict(self, sdo, strict=False)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
|
||||
202
comfy/utils.py
202
comfy/utils.py
@ -26,10 +26,13 @@ import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
from types import SimpleNamespace
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
from . import safetensors_stream
|
||||
import comfy.disk_weights
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
metadata = None
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device)
|
||||
if return_metadata:
|
||||
metadata = sd.metadata()
|
||||
except Exception as e:
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
params += w.nelement()
|
||||
meta = state_dict_meta(sd, k)
|
||||
params += meta.numel
|
||||
return params
|
||||
|
||||
def weight_dtype(sd, prefix=""):
|
||||
dtypes = {}
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
w = sd[k]
|
||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
|
||||
meta = state_dict_meta(sd, k)
|
||||
dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel
|
||||
|
||||
if len(dtypes) == 0:
|
||||
return None
|
||||
@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace):
|
||||
return state_dict
|
||||
|
||||
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
|
||||
if is_stream_state_dict(state_dict):
|
||||
return safetensors_stream.RenameViewStateDict(
|
||||
state_dict,
|
||||
replace_prefix,
|
||||
filter_keys=filter_keys,
|
||||
mutate_base=not filter_keys,
|
||||
)
|
||||
if filter_keys:
|
||||
out = {}
|
||||
else:
|
||||
@ -145,6 +149,79 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
|
||||
return out
|
||||
|
||||
|
||||
def is_stream_state_dict(state_dict) -> bool:
|
||||
return getattr(state_dict, "is_stream_state_dict", False)
|
||||
|
||||
|
||||
def state_dict_meta(state_dict, key):
|
||||
if hasattr(state_dict, "meta"):
|
||||
return state_dict.meta(key)
|
||||
w = state_dict[key]
|
||||
numel = w.numel()
|
||||
return SimpleNamespace(
|
||||
dtype=w.dtype,
|
||||
shape=tuple(w.shape),
|
||||
numel=numel,
|
||||
nbytes=numel * w.element_size(),
|
||||
)
|
||||
|
||||
|
||||
def load_state_dict(model, state_dict, strict=False, assign=False):
|
||||
if is_stream_state_dict(state_dict):
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
|
||||
comfy.disk_weights.register_module_weights(model, state_dict)
|
||||
comfy.disk_weights.attach_disk_weight_hooks(model)
|
||||
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
|
||||
return missing, unexpected
|
||||
return model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def stream_load_state_dict(model, state_dict, strict=False, assign=False):
|
||||
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
|
||||
state_dict = state_dict.copy()
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
|
||||
def load(module, local_state_dict, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
if assign:
|
||||
local_metadata["assign_to_params_buffers"] = assign
|
||||
module._load_from_state_dict(
|
||||
local_state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
True,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
child_prefix = f"{prefix}{name}."
|
||||
child_state_dict = safetensors_stream.FilterViewStateDict(
|
||||
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
|
||||
)
|
||||
load(child, child_state_dict, child_prefix)
|
||||
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
for hook in module._load_state_dict_post_hooks.values():
|
||||
out = hook(module, incompatible)
|
||||
if out is not None:
|
||||
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
|
||||
|
||||
load(model, state_dict)
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
@ -825,7 +902,10 @@ def copy_to_param(obj, attr, value):
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
prev.data.copy_(value)
|
||||
if prev.device.type == "meta":
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad))
|
||||
else:
|
||||
prev.data.copy_(value)
|
||||
|
||||
def get_attr(obj, attr: str):
|
||||
"""Retrieves a nested attribute from an object using dot notation.
|
||||
@ -1217,46 +1297,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
|
||||
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict[scaled_fp8_key]
|
||||
scaled_fp8_dtype = scaled_fp8_weight.dtype
|
||||
if is_stream_state_dict(state_dict):
|
||||
scaled_meta = state_dict_meta(state_dict, scaled_fp8_key)
|
||||
scaled_fp8_dtype = scaled_meta.dtype
|
||||
scaled_fp8_weight_nelements = scaled_meta.numel
|
||||
else:
|
||||
scaled_fp8_weight = state_dict[scaled_fp8_key]
|
||||
scaled_fp8_dtype = scaled_fp8_weight.dtype
|
||||
scaled_fp8_weight_nelements = scaled_fp8_weight.nelement()
|
||||
if scaled_fp8_dtype == torch.float32:
|
||||
scaled_fp8_dtype = torch.float8_e4m3fn
|
||||
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
if scaled_fp8_weight_nelements == 2:
|
||||
full_precision_matrix_mult = True
|
||||
else:
|
||||
full_precision_matrix_mult = False
|
||||
|
||||
out_sd = {}
|
||||
layers = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k == scaled_fp8_key:
|
||||
continue
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
k_out = k
|
||||
w = state_dict.pop(k)
|
||||
layer = None
|
||||
if k_out.endswith(".scale_weight"):
|
||||
layer = k_out[:-len(".scale_weight")]
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
if k_out.endswith(".scale_input"):
|
||||
layer = k_out[:-len(".scale_input")]
|
||||
k_out = "{}.input_scale".format(layer)
|
||||
if w.item() == 1.0:
|
||||
if is_stream_state_dict(state_dict):
|
||||
key_map = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k == scaled_fp8_key:
|
||||
continue
|
||||
if not k.startswith(model_prefix):
|
||||
key_map[k] = k
|
||||
continue
|
||||
k_out = k
|
||||
layer = None
|
||||
if k_out.endswith(".scale_weight"):
|
||||
layer = k_out[:-len(".scale_weight")]
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
out_sd[k_out] = w
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
state_dict = out_sd
|
||||
if k_out.endswith(".scale_input"):
|
||||
layer = k_out[:-len(".scale_input")]
|
||||
k_out = "{}.input_scale".format(layer)
|
||||
scale_val = state_dict[k]
|
||||
if scale_val.item() == 1.0:
|
||||
continue
|
||||
|
||||
key_map[k] = k_out
|
||||
state_dict = safetensors_stream.MappedStateDict(state_dict, key_map)
|
||||
else:
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k == scaled_fp8_key:
|
||||
continue
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
k_out = k
|
||||
w = state_dict.pop(k)
|
||||
layer = None
|
||||
if k_out.endswith(".scale_weight"):
|
||||
layer = k_out[:-len(".scale_weight")]
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
if k_out.endswith(".scale_input"):
|
||||
layer = k_out[:-len(".scale_input")]
|
||||
k_out = "{}.input_scale".format(layer)
|
||||
if w.item() == 1.0:
|
||||
continue
|
||||
|
||||
out_sd[k_out] = w
|
||||
|
||||
state_dict = out_sd
|
||||
quant_metadata = {"layers": layers}
|
||||
else:
|
||||
quant_metadata = json.loads(metadata["_quantization_metadata"])
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
@ -526,7 +525,7 @@ class LoadLatent:
|
||||
|
||||
def load(self, latent):
|
||||
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||
latent = comfy.utils.load_torch_file(latent_path, safe_load=True)
|
||||
multiplier = 1.0
|
||||
if "latent_format_version_0" not in latent:
|
||||
multiplier = 1.0 / 0.18215
|
||||
|
||||
@ -11,6 +11,7 @@ transformers>=4.50.3
|
||||
tokenizers>=0.13.3
|
||||
sentencepiece
|
||||
safetensors>=0.4.2
|
||||
fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip
|
||||
aiohttp>=3.11.8
|
||||
yarl>=1.18.0
|
||||
pyyaml
|
||||
|
||||
182
tests-unit/utils/safetensors_stream_test.py
Normal file
182
tests-unit/utils/safetensors_stream_test.py
Normal file
@ -0,0 +1,182 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def _write_safetensors(tmp_path, tensors):
|
||||
import safetensors.torch
|
||||
path = os.path.join(tmp_path, "test.safetensors")
|
||||
safetensors.torch.save_file(tensors, path)
|
||||
return path
|
||||
|
||||
|
||||
def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
meta = sd.meta("a")
|
||||
assert meta.shape == (2, 3)
|
||||
assert meta.dtype == torch.float32
|
||||
assert meta.numel == 6
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
_ = sd["a"]
|
||||
assert len(calls) == 1
|
||||
assert calls[0].shape == (2, 3)
|
||||
|
||||
|
||||
def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
|
||||
_ = list(view.keys())
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_stream_load_rss_small(tmp_path):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
psutil = pytest.importorskip("psutil")
|
||||
process = psutil.Process()
|
||||
size_elems = 4_000_000 # ~16MB float32
|
||||
tensor = torch.zeros((size_elems,), dtype=torch.float32)
|
||||
path = _write_safetensors(tmp_path, {"big": tensor})
|
||||
rss_before = process.memory_info().rss
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
rss_after = process.memory_info().rss
|
||||
expected_size = tensor.numel() * tensor.element_size()
|
||||
assert (rss_after - rss_before) < expected_size
|
||||
_ = sd.meta("big")
|
||||
|
||||
|
||||
def test_gds_path_errors_without_support(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
device = torch.device("cuda")
|
||||
|
||||
if importlib.util.find_spec("fastsafetensors") is None:
|
||||
fst = None
|
||||
else:
|
||||
fst = importlib.import_module("fastsafetensors")
|
||||
|
||||
gds_available = False
|
||||
if fst is not None and torch.cuda.is_available():
|
||||
gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
|
||||
gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1
|
||||
|
||||
if not gds_available:
|
||||
with pytest.raises(RuntimeError, match="GPUDirect requested"):
|
||||
sd.get_tensor("a", device=device, allow_gds=True)
|
||||
else:
|
||||
def fail_nogds(*args, **kwargs):
|
||||
raise AssertionError("nogds path used during GDS request")
|
||||
|
||||
monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
|
||||
t = sd.get_tensor("a", device=device, allow_gds=True)
|
||||
assert t.device.type == "cuda"
|
||||
|
||||
|
||||
def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
import comfy.disk_weights
|
||||
|
||||
prev_cache = comfy.disk_weights.CACHE.max_bytes
|
||||
prev_gds = comfy.disk_weights.ALLOW_GDS
|
||||
prev_pin = comfy.disk_weights.PIN_IF_CPU
|
||||
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
|
||||
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
|
||||
|
||||
try:
|
||||
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = torch.nn.Linear(4, 4, bias=True)
|
||||
comfy.utils.load_state_dict(model, sd, strict=False)
|
||||
assert model.weight.device.type == "cpu"
|
||||
assert model.weight.device.type != "meta"
|
||||
finally:
|
||||
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
|
||||
|
||||
|
||||
def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
|
||||
if importlib.util.find_spec("fastsafetensors") is None:
|
||||
pytest.skip("fastsafetensors not installed")
|
||||
import comfy.utils
|
||||
import comfy.disk_weights
|
||||
|
||||
prev_cache = comfy.disk_weights.CACHE.max_bytes
|
||||
prev_gds = comfy.disk_weights.ALLOW_GDS
|
||||
prev_pin = comfy.disk_weights.PIN_IF_CPU
|
||||
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
|
||||
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
|
||||
|
||||
try:
|
||||
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = torch.nn.Linear(4, 4, bias=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
comfy.utils.load_state_dict(model, sd, strict=True)
|
||||
assert model.weight.device.type == "meta"
|
||||
assert calls == []
|
||||
|
||||
comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
|
||||
assert model.weight.device.type == "cpu"
|
||||
assert len(calls) == 2
|
||||
finally:
|
||||
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
|
||||
Loading…
Reference in New Issue
Block a user