Add streaming safetensors loading with disk weight cache

parent 3cd7b32f1b
commit f925f8fa77

DESIGN.md (new file, 57 lines)

@@ -0,0 +1,57 @@
# Disk tier safetensors streaming design audit (ComfyUI)

## Mandatory research audit (verified call sites)
### ComfyUI load path + eager materialization sites

- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors (a sketch of this eager pattern follows the list).【F:comfy/utils.py†L58-L93】
- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】
- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes tensors if used on a streaming mapping).【F:comfy/utils.py†L135-L144】
- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】
- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must become metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and the special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
- Direct safetensors loads outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
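For context, a minimal sketch contrasting the eager pattern under audit with the deferred access the design aims for. `safetensors.safe_open` itself lists keys cheaply; the cost comes from calling `get_tensor` for every key up front. The path argument is illustrative:

```python
from safetensors import safe_open

# Eager: every tensor is read into RAM before anything else runs.
def load_eager(path):
    sd = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for k in f.keys():
            sd[k] = f.get_tensor(k)  # full read per key
    return sd

# Deferred: read a single tensor only when it is actually needed.
def load_one(path, key):
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(key)
```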
### fastsafetensors capability audit

- Header parsing and metadata:
  - `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
  - `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
- GDS + no-GDS low-level readers:
  - `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
  - GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
- DLPack wrapping:
  - `fastsafetensors/dlpack.py` provides `from_cuda_buffer()`, which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used with `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
- Torch framework interop:
  - `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` to wrap raw pointers into tensors (a header-only probe using these calls is sketched below).【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
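As a point of reference, a minimal header-only probe following the calls this commit's `SafeTensorIndex` makes; the `tensors` mapping attribute on the parsed metadata is an assumption for illustration:

```python
import fastsafetensors as fst

def probe_header(filename: str):
    fst.cpp.load_library_functions()                       # bind the C++ helpers
    framework = fst.frameworks.get_framework_op("pytorch")
    # Parses only the JSON header; no tensor data is read or allocated.
    metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
    for name, frame in metadata.tensors.items():           # assumed attribute
        print(name, frame.dtype, frame.shape, frame.data_offsets)
```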
### VRAM/RAM offload logic (for extension)

- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps track of loaded/offloaded memory (needs integration for the RAM disk tier).【F:comfy/model_management.py†L584-L612】
- `comfy/model_patcher.py` implements module-by-module offload/low-VRAM weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate the disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
## Strategy summary (no coding performed yet)

### Streaming safetensors mapping (no full dict materialization)

- Introduce a new module `comfy/safetensors_stream.py` with:
  - `TensorMeta` and `SafeTensorIndex` (metadata-only parsing via `fastsafetensors.SafeTensorsMetadata`).
  - `StreamStateDict`, a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
  - Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading (a sketch of the mapping idea follows this list).
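A minimal sketch of the lazy-mapping idea, assuming a `SafeTensorIndex` that can list keys and fetch a single tensor; the names match this commit's design but the bodies are illustrative:

```python
from collections.abc import Mapping

class StreamStateDict(Mapping):
    """Dict-like view over a safetensors file; reads one tensor per access."""
    def __init__(self, index):
        self._index = index  # SafeTensorIndex: knows keys + offsets

    def __iter__(self):
        return iter(self._index.keys())     # metadata only

    def __len__(self):
        return len(self._index.keys())

    def __getitem__(self, key):
        return self._index.get_tensor(key)  # range read, on demand

class PrefixViewStateDict(Mapping):
    """Lazy view that strips a key prefix without copying any tensors."""
    def __init__(self, base, prefix):
        self._base = base
        self._prefix = prefix
        self._keys = [k[len(prefix):] for k in base if k.startswith(prefix)]

    def __iter__(self):
        return iter(self._keys)

    def __len__(self):
        return len(self._keys)

    def __getitem__(self, key):
        return self._base[self._prefix + key]
```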
### Range reads and tiering

- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap the resulting buffer with DLPack.
- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported` returns 0 or libcufile is missing), raise a hard error with instructions to disable GDS.
- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release the CPU buffer unless the RAM cache policy keeps it (see the range-read sketch after this list).
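A simplified sketch of a disk→RAM→GPU range read; it uses plain `os.pread` as a stand-in for `nogds_file_reader`, whose exact calling convention is not reproduced here, and skips the pinned-memory and alignment handling a real implementation would add:

```python
import os
import torch

def read_tensor_range(path: str, offset: int, nbytes: int,
                      shape, dtype: torch.dtype) -> torch.Tensor:
    fd = os.open(path, os.O_RDONLY)
    try:
        raw = os.pread(fd, nbytes, offset)   # read only this tensor's bytes
    finally:
        os.close(fd)
    # bytearray() gives torch a writable buffer to wrap without another copy.
    return torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape)

def read_to_gpu(path, offset, nbytes, shape, dtype, device="cuda"):
    cpu = read_tensor_range(path, offset, nbytes, shape, dtype)
    return cpu.to(device)                    # staging copy; CPU buffer can be dropped
```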
### Disk tier integration

- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that maps `(module, param_name) -> TensorMeta + loader handle`.
- Add an LRU cache for RAM-resident weights loaded from disk, with a configurable byte limit. Eviction replaces RAM tensors with meta tensors and keeps the `DiskRef` for reload.
- Add a general `forward_pre_hook` to materialize any meta+`DiskRef` weights before compute; this covers modules that bypass `comfy.ops` (sketched after this list).
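The hook mechanism in miniature, mirroring the `disk_weight_pre_hook` added in `comfy/disk_weights.py` below; `load_from_disk` is a hypothetical loader standing in for the `DiskRef` machinery:

```python
import torch

def materialize_pre_hook(module, args):
    # Swap meta-device weights for real tensors right before forward() runs.
    for name, p in list(module.named_parameters(recurse=False)):
        if p.device.type == "meta":
            tensor = load_from_disk(module, name)   # hypothetical loader
            module._parameters[name] = torch.nn.Parameter(
                tensor, requires_grad=p.requires_grad)

def attach_hooks(model: torch.nn.Module):
    for m in model.modules():
        m.register_forward_pre_hook(materialize_pre_hook)
```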
### Pipeline refactors

- Update `load_torch_file` to return a `StreamStateDict` for `.safetensors`/`.sft` files and to return metadata without loading tensors.
- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy (see the sketch after this list).
- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings plus view wrappers instead.
- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
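A sketch of a metadata-aware `calculate_parameters`, assuming the streaming mapping exposes the `meta(key)` accessor that `comfy/disk_weights.py` below checks for (its `TensorMeta` carries a `numel` field); the fallback branch matches today's eager behavior:

```python
def calculate_parameters(sd, prefix=""):
    params = 0
    for k in sd.keys():
        if not k.startswith(prefix):
            continue
        if hasattr(sd, "meta"):        # streaming mapping: header info only
            params += sd.meta(k).numel
        else:                          # plain dict: tensor is already in RAM
            params += sd[k].nelement()
    return params
```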
### Tests and docs

- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization; a sample laziness test follows this list), plus integration tests for load behavior and the GDS failure path.
- Document the new flags for RAM cache size and GPUDirect enablement, and how to disable GDS when it is unsupported.
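One way such a laziness test could look, reusing the `StreamStateDict`/`PrefixViewStateDict` sketch from above; the `FakeIndex` is a test double, not part of the design:

```python
class FakeIndex:
    """Records which tensors are actually read."""
    def __init__(self, keys):
        self._keys = list(keys)
        self.reads = []

    def keys(self):
        return self._keys

    def get_tensor(self, key):
        self.reads.append(key)
        return object()  # stand-in tensor

def test_prefix_view_is_lazy():
    idx = FakeIndex(["model.a", "model.b", "other.c"])
    view = PrefixViewStateDict(StreamStateDict(idx), "model.")
    assert sorted(view) == ["a", "b"]   # listing keys reads no tensors
    assert idx.reads == []
    view["a"]                            # a single access reads one tensor
    assert idx.reads == ["model.a"]
```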
@@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
 | `--enable-manager` | Enable ComfyUI-Manager |
 | `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
 | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
+| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
+
+### Disk tier for model weights
+
+When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
+
+If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or an unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
 
 # Running
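As a usage sketch (flag names from this commit; the 8 GB cache size is just an example):

```
python main.py --weights-ram-cache-gb 8
python main.py --weights-ram-cache-gb 8 --weights-gds   # needs libcufile + GDS
```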
@@ -29,7 +29,7 @@ class AudioEncoderModel():
         self.model_sample_rate = 16000
 
     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=False)
+        return comfy.utils.load_state_dict(self.model, sd, strict=False)
 
     def get_sd(self):
         return self.model.state_dict()
@@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
 cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
 
+parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
+parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
+
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
@@ -48,7 +48,7 @@ class ClipVisionModel():
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
 
     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=False)
+        return comfy.utils.load_state_dict(self.model, sd, strict=False)
 
     def get_sd(self):
         return self.model.state_dict()
@@ -439,7 +439,7 @@ def controlnet_config(sd, model_options={}):
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
 
 def controlnet_load_state_dict(control_model, sd):
-    missing, unexpected = control_model.load_state_dict(sd, strict=False)
+    missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
 
     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))
@@ -473,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
 class ControlNetSD35(ControlNet):
     def pre_run(self, model, percent_to_timestep_function):
         if self.control_model.double_y_emb:
-            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+            missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
         else:
-            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+            missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
         super().pre_run(model, percent_to_timestep_function)
 
     def copy(self):
@@ -748,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
             pass
        w = WeightsLoader()
        w.control_model = control_model
-        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+        missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
     else:
-        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+        missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
 
     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))
@@ -874,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
     else:
         return None
 
-    missing, unexpected = model_ad.load_state_dict(t2i_data)
+    missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
     if len(missing) > 0:
         logging.warning("t2i missing {}".format(missing))
comfy/disk_weights.py (new file, 275 lines)

@@ -0,0 +1,275 @@
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
ALLOW_GDS = False
|
||||
PIN_IF_CPU = False
|
||||
DISK_WEIGHTS_ENABLED = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiskTensorRef:
|
||||
state_dict: object
|
||||
key: str
|
||||
meta: object
|
||||
requires_grad: bool
|
||||
is_buffer: bool
|
||||
|
||||
def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor:
|
||||
dtype = getattr(self.meta, "dtype", None)
|
||||
if hasattr(self.state_dict, "get_tensor"):
|
||||
return self.state_dict.get_tensor(
|
||||
self.key,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
allow_gds=allow_gds,
|
||||
pin_if_cpu=pin_if_cpu,
|
||||
)
|
||||
tensor = self.state_dict[self.key]
|
||||
if device is not None and tensor.device != device:
|
||||
tensor = tensor.to(device=device)
|
||||
if dtype is not None and tensor.dtype != dtype:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
class DiskWeightRegistry:
|
||||
def __init__(self):
|
||||
self._registry = weakref.WeakKeyDictionary()
|
||||
|
||||
def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef):
|
||||
module_refs = self._registry.setdefault(module, {})
|
||||
module_refs[name] = ref
|
||||
|
||||
def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]:
|
||||
return self._registry.get(module)
|
||||
|
||||
def has(self, module: torch.nn.Module) -> bool:
|
||||
return module in self._registry
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
module_ref: weakref.ReferenceType
|
||||
name: str
|
||||
size_bytes: int
|
||||
is_buffer: bool
|
||||
|
||||
|
||||
class DiskWeightCache:
|
||||
def __init__(self, max_bytes: int = 0):
|
||||
self.max_bytes = max_bytes
|
||||
self.current_bytes = 0
|
||||
self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict()
|
||||
|
||||
def set_limit(self, max_bytes: int):
|
||||
self.max_bytes = max_bytes
|
||||
self._evict_if_needed()
|
||||
|
||||
def _entry_key(self, module: torch.nn.Module, name: str) -> tuple[int, str]:
|
||||
return (id(module), name)
|
||||
|
||||
def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
|
||||
if tensor.device.type != "cpu":
|
||||
return
|
||||
size_bytes = tensor.numel() * tensor.element_size()
|
||||
key = self._entry_key(module, name)
|
||||
if key in self._entries:
|
||||
entry = self._entries.pop(key)
|
||||
self.current_bytes -= entry.size_bytes
|
||||
module_ref = weakref.ref(module, self._drop_module_entries)
|
||||
self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
|
||||
self.current_bytes += size_bytes
|
||||
self._evict_if_needed()
|
||||
|
||||
def touch(self, module: torch.nn.Module, name: str):
|
||||
key = self._entry_key(module, name)
|
||||
if key in self._entries:
|
||||
entry = self._entries.pop(key)
|
||||
self._entries[key] = entry
|
||||
|
||||
def evict_bytes(self, bytes_to_free: int):
|
||||
freed = 0
|
||||
while self._entries and freed < bytes_to_free:
|
||||
_, entry = self._entries.popitem(last=False)
|
||||
freed += entry.size_bytes
|
||||
self.current_bytes -= entry.size_bytes
|
||||
module = entry.module_ref()
|
||||
if module is not None:
|
||||
_evict_module_weight(module, entry.name, entry.is_buffer)
|
||||
return freed
|
||||
|
||||
def _drop_module_entries(self, module_ref: weakref.ReferenceType):
|
||||
to_remove = []
|
||||
for key, entry in self._entries.items():
|
||||
if entry.module_ref is module_ref:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
entry = self._entries.pop(key)
|
||||
self.current_bytes -= entry.size_bytes
|
||||
|
||||
def _evict_if_needed(self):
|
||||
while self._entries and self.current_bytes > self.max_bytes:
|
||||
_, entry = self._entries.popitem(last=False)
|
||||
self.current_bytes -= entry.size_bytes
|
||||
module = entry.module_ref()
|
||||
if module is not None:
|
||||
_evict_module_weight(module, entry.name, entry.is_buffer)
|
||||
|
||||
|
||||
REGISTRY = DiskWeightRegistry()
|
||||
CACHE = DiskWeightCache(0)
|
||||
|
||||
|
||||
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
|
||||
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
|
||||
ALLOW_GDS = allow_gds
|
||||
PIN_IF_CPU = pin_if_cpu
|
||||
DISK_WEIGHTS_ENABLED = enabled
|
||||
CACHE.set_limit(cache_bytes if enabled else 0)
|
||||
if not enabled:
|
||||
CACHE._entries.clear()
|
||||
CACHE.current_bytes = 0
|
||||
|
||||
|
||||
def disk_weights_enabled() -> bool:
|
||||
return DISK_WEIGHTS_ENABLED
|
||||
|
||||
|
||||
def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
|
||||
if not disk_weights_enabled():
|
||||
return
|
||||
if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
|
||||
return
|
||||
for name, param in module.named_parameters(recurse=True):
|
||||
key = f"{prefix}{name}" if prefix else name
|
||||
if key in state_dict:
|
||||
meta = state_dict.meta(key)
|
||||
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
|
||||
REGISTRY.register(module, name, ref)
|
||||
if param.device.type == "cpu":
|
||||
CACHE.record(module, name, param, is_buffer=False)
|
||||
for name, buf in module.named_buffers(recurse=True):
|
||||
key = f"{prefix}{name}" if prefix else name
|
||||
if key in state_dict and buf is not None:
|
||||
meta = state_dict.meta(key)
|
||||
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
|
||||
REGISTRY.register(module, name, ref)
|
||||
if buf.device.type == "cpu":
|
||||
CACHE.record(module, name, buf, is_buffer=True)
|
||||
|
||||
|
||||
def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
|
||||
ref = REGISTRY.get(module)
|
||||
if not ref or name not in ref:
|
||||
return
|
||||
disk_ref = ref[name]
|
||||
shape = getattr(disk_ref.meta, "shape", None)
|
||||
dtype = getattr(disk_ref.meta, "dtype", None)
|
||||
if shape is None or dtype is None:
|
||||
return
|
||||
meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
|
||||
if is_buffer:
|
||||
module._buffers[name] = meta_tensor
|
||||
else:
|
||||
module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
|
||||
|
||||
|
||||
def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
|
||||
def check(obj):
|
||||
if torch.is_tensor(obj):
|
||||
return obj.device
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
dev = check(item)
|
||||
if dev is not None:
|
||||
return dev
|
||||
if isinstance(obj, dict):
|
||||
for item in obj.values():
|
||||
dev = check(item)
|
||||
if dev is not None:
|
||||
return dev
|
||||
return None
|
||||
|
||||
dev = check(args)
|
||||
if dev is not None:
|
||||
return dev
|
||||
return check(kwargs)
|
||||
|
||||
|
||||
def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
|
||||
refs = REGISTRY.get(module)
|
||||
if not refs:
|
||||
return
|
||||
for name, disk_ref in refs.items():
|
||||
if name in module._parameters:
|
||||
current = module._parameters[name]
|
||||
is_buffer = False
|
||||
elif name in module._buffers:
|
||||
current = module._buffers[name]
|
||||
is_buffer = True
|
||||
else:
|
||||
continue
|
||||
if current is None:
|
||||
continue
|
||||
if current.device.type != "meta":
|
||||
if current.device.type == "cpu":
|
||||
CACHE.touch(module, name)
|
||||
continue
|
||||
tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
|
||||
if is_buffer:
|
||||
module._buffers[name] = tensor
|
||||
else:
|
||||
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
||||
if tensor.device.type == "cpu":
|
||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||
|
||||
|
||||
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
|
||||
if not REGISTRY.has(module):
|
||||
return
|
||||
if getattr(module, "comfy_cast_weights", False):
|
||||
target_device = torch.device("cpu")
|
||||
else:
|
||||
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
|
||||
ensure_module_materialized(module, target_device)
|
||||
|
||||
|
||||
def attach_disk_weight_hooks(model: torch.nn.Module):
|
||||
if not disk_weights_enabled():
|
||||
return
|
||||
for module in model.modules():
|
||||
if getattr(module, "_disk_weight_hook_attached", False):
|
||||
continue
|
||||
module.register_forward_pre_hook(disk_weight_pre_hook)
|
||||
module._disk_weight_hook_attached = True
|
||||
|
||||
|
||||
def evict_ram_cache(bytes_to_free: int):
|
||||
if bytes_to_free <= 0:
|
||||
return 0
|
||||
return CACHE.evict_bytes(bytes_to_free)
|
||||
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from .ldm.modules.attention import CrossAttention, FeedForward
 import comfy.ops
+import comfy.utils
 ops = comfy.ops.manual_cast
 
 
@@ -282,7 +283,7 @@ def load_gligen(sd):
 
         gated = GatedSelfAttentionDense(
             query_dim, key_dim, n_heads, d_head)
-        gated.load_state_dict(n_sd, strict=False)
+        comfy.utils.load_state_dict(gated, n_sd, strict=False)
         output_list.append(gated)
 
     if "position_net.null_positive_feature" in sd_k:
@@ -293,7 +294,7 @@ def load_gligen(sd):
             pass
         w = WeightsLoader()
         w.position_net = PositionNet(in_dim, out_dim)
-        w.load_state_dict(sd, strict=False)
+        comfy.utils.load_state_dict(w, sd, strict=False)
 
         gligen = Gligen(output_list, w.position_net, key_dim)
         return gligen
@@ -1,4 +1,5 @@
 import torch
+import comfy.utils
 import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
         self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
 
     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=True)
+        return comfy.utils.load_state_dict(self.model, sd, strict=True)
 
     def get_sd(self):
         return self.model.state_dict()
@@ -2,6 +2,7 @@ import json
 from dataclasses import dataclass
 import math
 import torch
+import comfy.utils
 import torchaudio
 
 import comfy.model_management
@@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
         self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
         self.vocoder = Vocoder(config=component_config.vocoder)
 
-        self.autoencoder.load_state_dict(vae_sd, strict=False)
-        self.vocoder.load_state_dict(vocoder_sd, strict=False)
+        comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
+        comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
 
         autoencoder_config = self.autoencoder.get_config()
         self.normalizer = AudioLatentNormalizer(
@@ -2,6 +2,7 @@ import logging
 from typing import Optional
 
 import torch
+import comfy.utils
 import torch.nn as nn
 
 from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@@ -152,7 +153,7 @@ class VAE(nn.Module):
         return dec, posterior
 
     def load_weights(self, src_dict) -> None:
-        self.load_state_dict(src_dict, strict=True)
+        comfy.utils.load_state_dict(self, src_dict, strict=True)
 
     @property
     def device(self) -> torch.device:
@@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
     if name == '44k':
         return VAE_44k(**kwargs)
     raise ValueError(f'Unknown model: {name}')
-
@@ -299,20 +299,14 @@ class BaseModel(torch.nn.Module):
         return out
 
     def load_model_weights(self, sd, unet_prefix=""):
-        to_load = {}
-        keys = list(sd.keys())
-        for k in keys:
-            if k.startswith(unet_prefix):
-                to_load[k[len(unet_prefix):]] = sd.pop(k)
-
+        to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
         to_load = self.model_config.process_unet_state_dict(to_load)
-        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))
 
         if len(u) > 0:
             logging.warning("unet unexpected: {}".format(u))
-        del to_load
         return self
 
     def process_latent_in(self, latent):
@@ -751,8 +745,8 @@ class StableAudio1(BaseModel):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
         self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
         self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
-        self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
-        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
+        utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
+        utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
 
     def extra_conds(self, **kwargs):
         out = {}
@@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string):
         count += 1
     return count
 
+def sd_shape(state_dict, key):
+    return comfy.utils.state_dict_meta(state_dict, key).shape
+
 def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
     context_dim = None
     use_linear_in_transformer = False
@@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
     transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
     if len(transformer_keys) > 0:
         last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
-        context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
-        use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
+        context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1]
+        use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2
         time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
         time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
         return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
@@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
     if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
         unet_config = {}
-        unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
-        patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
+        unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1]
+        patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2]
         unet_config["patch_size"] = patch_size
         final_layer = '{}final_layer.linear.weight'.format(key_prefix)
         if final_layer in state_dict:
-            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
+            unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size)
 
-        unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
+        unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64
         unet_config["input_size"] = None
         y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
         if y_key in state_dict_keys:
-            unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
+            unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1]
 
         context_key = '{}context_embedder.weight'.format(key_prefix)
         if context_key in state_dict_keys:
-            in_features = state_dict[context_key].shape[1]
-            out_features = state_dict[context_key].shape[0]
+            in_features = sd_shape(state_dict, context_key)[1]
+            out_features = sd_shape(state_dict, context_key)[0]
             unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
         num_patches_key = '{}pos_embed'.format(key_prefix)
         if num_patches_key in state_dict_keys:
-            num_patches = state_dict[num_patches_key].shape[1]
+            num_patches = sd_shape(state_dict, num_patches_key)[1]
             unet_config["num_patches"] = num_patches
             unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
@@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
     if text_mapper_name in state_dict_keys:
         unet_config['stable_cascade_stage'] = 'c'
-        w = state_dict[text_mapper_name]
-        if w.shape[0] == 1536: #stage c lite
+        w_shape = sd_shape(state_dict, text_mapper_name)
+        if w_shape[0] == 1536: #stage c lite
             unet_config['c_cond'] = 1536
             unet_config['c_hidden'] = [1536, 1536]
             unet_config['nhead'] = [24, 24]
             unet_config['blocks'] = [[4, 12], [12, 4]]
-        elif w.shape[0] == 2048: #stage c full
+        elif w_shape[0] == 2048: #stage c full
             unet_config['c_cond'] = 2048
     elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
         unet_config['stable_cascade_stage'] = 'b'
-        w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
-        if w.shape[-1] == 640:
+        w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix))
+        if w_shape[-1] == 640:
             unet_config['c_hidden'] = [320, 640, 1280, 1280]
             unet_config['nhead'] = [-1, -1, 20, 20]
             unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
             unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
-        elif w.shape[-1] == 576: #stage b lite
+        elif w_shape[-1] == 576: #stage b lite
             unet_config['c_hidden'] = [320, 576, 1152, 1152]
             unet_config['nhead'] = [-1, 9, 18, 18]
             unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
@@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
     if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
         unet_config = {}
-        unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
-        unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+        unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1]
+        unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1]
         double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
         single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
         unet_config["n_double_layers"] = double_layers
@@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         unet_config = {}
         unet_config["image_model"] = "hydit"
         unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
-        unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
+        unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0]
         if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
             unet_config["mlp_ratio"] = 4.3637
-        if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
+        if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968:
             unet_config["size_cond"] = True
             unet_config["use_style_cond"] = True
             unet_config["image_model"] = "hydit1"
@@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
     if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
         dit_config = {}
-        in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
-        out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
+        in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix))
+        out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix))
         dit_config["image_model"] = "hunyuan_video"
-        dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
-        dit_config["patch_size"] = list(in_w.shape[2:])
-        dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+        dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels
+        dit_config["patch_size"] = list(in_w_shape[2:])
+        dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"])
         if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
             dit_config["vec_in_dim"] = 768
         else:
@@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         else:
             dit_config["meanflow"] = False
 
-        dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
-        dit_config["hidden_size"] = in_w.shape[0]
+        dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1]
+        dit_config["hidden_size"] = in_w_shape[0]
         dit_config["mlp_ratio"] = 4.0
-        dit_config["num_heads"] = in_w.shape[0] // 128
+        dit_config["num_heads"] = in_w_shape[0] // 128
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
         dit_config["theta"] = 256
@@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         else:
             dit_config["use_cond_type_embedding"] = False
         if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
-            dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+            dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0]
             dit_config["meanflow_sum"] = True
         else:
             dit_config["vision_in_dim"] = None
@@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["patch_size"] = patch_size
         in_key = "{}img_in.weight".format(key_prefix)
         if in_key in state_dict_keys:
-            w = state_dict[in_key]
-            dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
-            dit_config["hidden_size"] = w.shape[0]
+            w_shape = sd_shape(state_dict, in_key)
+            dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size)
+            dit_config["hidden_size"] = w_shape[0]
 
         txt_in_key = "{}txt_in.weight".format(key_prefix)
         if txt_in_key in state_dict_keys:
-            w = state_dict[txt_in_key]
-            dit_config["context_in_dim"] = w.shape[1]
-            dit_config["hidden_size"] = w.shape[0]
+            w_shape = sd_shape(state_dict, txt_in_key)
+            dit_config["context_in_dim"] = w_shape[1]
+            dit_config["hidden_size"] = w_shape[0]
 
         vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
         if vec_in_key in state_dict_keys:
-            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+            dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1]
         else:
             dit_config["vec_in_dim"] = None
@@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config = {}
         dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
-        shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
+        shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix))
         dit_config["attention_head_dim"] = shape[0] // 32
         dit_config["cross_attention_dim"] = shape[1]
         if metadata is not None and "config" in metadata:
@@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         y_key = "{}y_embedder.y_embedding".format(key_prefix)
         if y_key in state_dict_keys:
-            dit_config["model_max_length"] = state_dict[y_key].shape[0]
+            dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0]
 
         pe_key = "{}pos_embed".format(key_prefix)
         if pe_key in state_dict_keys:
-            dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
+            dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size
             dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
 
         ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
@@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["max_img_w"] = 240
         dit_config["max_frames"] = 128
         concat_padding_mask = True
-        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+        dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
         dit_config["out_channels"] = 16
         dit_config["patch_spatial"] = 2
         dit_config["patch_temporal"] = 1
-        dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
+        dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0]
         dit_config["block_config"] = "FA-CA-MLP"
         dit_config["concat_padding_mask"] = concat_padding_mask
         dit_config["pos_emb_cls"] = "rope3d"
@@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
-        w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
-        dit_config["dim"] = w.shape[0]
-        dit_config["cap_feat_dim"] = w.shape[1]
+        w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix))
+        dit_config["dim"] = w_shape[0]
+        dit_config["cap_feat_dim"] = w_shape[1]
         dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
         dit_config["qk_norm"] = True
 
@@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["axes_lens"] = [300, 512, 512]
         dit_config["rope_theta"] = 10000.0
         dit_config["ffn_dim_multiplier"] = 4.0
-        ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
-        if ctd_weight is not None: # NewBie
-            dit_config["clip_text_dim"] = ctd_weight.shape[0]
+        ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix)
+        if ctd_key in state_dict_keys: # NewBie
+            dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0]
             # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
         elif dit_config["dim"] == 3840: # Z image
             dit_config["n_heads"] = 30
@@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
         dit_config = {}
         dit_config["image_model"] = "wan2.1"
-        dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
-        out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
+        dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1]
+        out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4
         dit_config["dim"] = dim
         dit_config["out_dim"] = out_dim
         dit_config["num_heads"] = dim // 128
-        dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
+        dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
         dit_config["patch_size"] = (1, 2, 2)
         dit_config["freq_dim"] = 256
@@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["qk_norm"] = True
         dit_config["cross_attn_norm"] = True
         dit_config["eps"] = 1e-6
-        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
+        dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1]
         if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
             dit_config["model_type"] = "vace"
-            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
+            dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1]
             dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
         elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
@@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["model_type"] = "i2v"
         else:
             dit_config["model_type"] = "t2v"
-        flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
-        if flf_weight is not None:
-            dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
+        flf_key = '{}img_emb.emb_pos'.format(key_prefix)
+        if flf_key in state_dict_keys:
+            dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1]
 
-        ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
-        if ref_conv_weight is not None:
-            dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
+        ref_conv_key = '{}ref_conv.weight'.format(key_prefix)
+        if ref_conv_key in state_dict_keys:
+            dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1]
 
         return dit_config
 
     if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
-        in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
+        in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix))
         dit_config = {}
         dit_config["image_model"] = "hunyuan3d2"
         dit_config["in_channels"] = in_shape[1]
-        dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
+        dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1]
         dit_config["hidden_size"] = in_shape[0]
         dit_config["mlp_ratio"] = 4.0
         dit_config["num_heads"] = 16
@@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         dit_config = {}
         dit_config["image_model"] = "hunyuan3d2_1"
-        dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+        dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1]
         dit_config["context_dim"] = 1024
-        dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+        dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0]
         dit_config["mlp_ratio"] = 4.0
         dit_config["num_heads"] = 16
         dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
@@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["max_img_w"] = 240
         dit_config["max_frames"] = 128
         concat_padding_mask = True
-        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+        dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
        dit_config["out_channels"] = 16
         dit_config["patch_spatial"] = 2
         dit_config["patch_temporal"] = 1
-        dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
+        dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0]
         dit_config["concat_padding_mask"] = concat_padding_mask
         dit_config["crossattn_emb_channels"] = 1024
         dit_config["pos_emb_cls"] = "rope3d"
@@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
         dit_config = {}
         dit_config["image_model"] = "qwen_image"
-        dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+        dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
             dit_config["default_ref_method"] = "index_timestep_zero"
@@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
         dit_config = {}
-        model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+        model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0]
         dit_config["model_dim"] = model_dim
         if model_dim in [4096, 2560]: # pro video and lite image
             dit_config["axes_dims"] = (32, 48, 48)
@@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
         elif model_dim == 1792: # lite video
             dit_config["axes_dims"] = (16, 24, 24)
-        dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+        dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0]
         dit_config["image_model"] = "kandinsky5"
-        dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
-        dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+        dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0]
+        dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1]
         dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
         dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
         return dit_config
@@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     y_input = '{}label_emb.0.0.weight'.format(key_prefix)
     if y_input in state_dict_keys:
         unet_config["num_classes"] = "sequential"
-        unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
+        unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1]
     else:
         unet_config["adm_in_channels"] = None
 
-    model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
-    in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
+    model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0]
+    in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1]
 
     out_key = '{}out.2.weight'.format(key_prefix)
     if out_key in state_dict:
-        out_channels = state_dict[out_key].shape[0]
+        out_channels = sd_shape(state_dict, out_key)[0]
     else:
         out_channels = 4
 
@@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
             if res_block_prefix in block_keys:
                 last_res_blocks += 1
-                last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
+                last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels
 
                 out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
                 if out is not None:
@@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
             transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
             transformer_depth.append(transformer_count)
             if transformer_count > 0:
-                match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
+                match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1]
 
         attn_res *= 2
         if attn_blocks == 0:
@@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
 
     match["transformer_depth"] = transformer_depth
 
-    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
-    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+    match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0]
+    match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1]
     match["adm_in_channels"] = None
     if "class_embedding.linear_1.weight" in state_dict:
-        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+        match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1]
     elif "add_embedding.linear_1.weight" in state_dict:
-        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+        match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1]
 
     SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
@@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
     elif 'x_embedder.weight' in state_dict: #Flux
         depth = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
-        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        hidden_size = sd_shape(state_dict, "x_embedder.bias")[0]
         sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
     elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
-        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
+        depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
     else:
         return None
@@ -27,6 +27,7 @@ import platform
 import weakref
 import gc
 import os
+import comfy.disk_weights
 
 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
@@ -583,6 +584,8 @@ def minimum_inference_memory():
 
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
+    if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
+        comfy.disk_weights.evict_ram_cache(memory_required)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -1124,6 +1127,16 @@ if not args.disable_pinned_memory:
         MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
     logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
 
+WEIGHTS_RAM_CACHE_BYTES = 0
+WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
+if args.weights_ram_cache_gb is not None:
+    WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
+    comfy.disk_weights.configure(
+        WEIGHTS_RAM_CACHE_BYTES,
+        allow_gds=WEIGHTS_GDS_ENABLED,
+        pin_if_cpu=not args.disable_pinned_memory,
+    )
+
 PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
 
 def discard_cuda_async_error():
@@ -34,6 +34,7 @@ import comfy.lora
 import comfy.model_management
 import comfy.patcher_extension
 import comfy.utils
+import comfy.disk_weights
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -269,6 +270,8 @@ class ModelPatcher:
         if not hasattr(self.model, 'model_offload_buffer_memory'):
             self.model.model_offload_buffer_memory = 0
 
+        comfy.disk_weights.attach_disk_weight_hooks(self.model)
+
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -1356,4 +1359,3 @@ class ModelPatcher:
     def __del__(self):
         self.unpin_all_weights()
         self.detach(unpatch_all=False)
-
799  comfy/safetensors_stream.py  Normal file
@@ -0,0 +1,799 @@
"""
|
||||
This file is part of ComfyUI.
|
||||
Copyright (C) 2024 Comfy
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
_FST_MODULE = None
|
||||
_FST_LOCK = threading.Lock()
|
||||
_FST_LOADED = False
|
||||
_GDS_INITIALIZED = False
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
def _require_fastsafetensors():
|
||||
global _FST_MODULE
|
||||
with _FST_LOCK:
|
||||
if _FST_MODULE is None:
|
||||
if importlib.util.find_spec("fastsafetensors") is None:
|
||||
raise ImportError(
|
||||
"fastsafetensors is required for safetensors streaming. "
|
||||
"Install it with: pip install 'fastsafetensors @ https://github.com/"
|
||||
"foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
|
||||
)
|
||||
_FST_MODULE = importlib.import_module("fastsafetensors")
|
||||
return _FST_MODULE
|
||||
|
||||
|
||||
def _init_fastsafetensors_lib():
|
||||
global _FST_LOADED
|
||||
fst = _require_fastsafetensors()
|
||||
if not _FST_LOADED:
|
||||
fst.cpp.load_library_functions()
|
||||
_FST_LOADED = True
|
||||
return fst
|
||||
|
||||
|
||||
def _init_gds():
|
||||
global _GDS_INITIALIZED
|
||||
fst = _init_fastsafetensors_lib()
|
||||
if not _GDS_INITIALIZED:
|
||||
if fst.cpp.init_gds() != 0:
|
||||
raise RuntimeError("fastsafetensors init_gds() failed")
|
||||
_GDS_INITIALIZED = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorMeta:
    dtype: torch.dtype
    shape: Tuple[int, ...]
    numel: int
    nbytes: int
    data_offsets: Tuple[int, int]
    filename: str
    fst_dtype: object
    strides: Tuple[int, ...]


class SafeTensorIndex:
    def __init__(self, filename: str):
        fst = _init_fastsafetensors_lib()
        framework = fst.frameworks.get_framework_op("pytorch")
        metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
        self._filename = filename
        self._metadata = metadata
        self._framework = framework
        from fastsafetensors.frameworks import _torch as fst_torch
        self._dtype_map = fst_torch.dtype_convert
        self._tensor_meta: Dict[str, TensorMeta] = {}
        for key, frame in metadata.tensors.items():
            torch_dtype = self._dtype_map.get(frame.dtype, None)
            if torch_dtype is None:
                raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
            numel = 1
            for s in frame.shape:
                numel *= s
            nbytes = numel * framework.get_dtype_size(frame.dtype)
            self._tensor_meta[key] = TensorMeta(
                dtype=torch_dtype,
                shape=tuple(frame.shape),
                numel=numel,
                nbytes=nbytes,
                data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
                filename=filename,
                fst_dtype=frame.dtype,
                strides=tuple(frame.strides),
            )

    def keys(self) -> Iterable[str]:
        return self._tensor_meta.keys()

    def has(self, key: str) -> bool:
        return key in self._tensor_meta

    def meta(self, key: str) -> TensorMeta:
        return self._tensor_meta[key]

    def metadata(self):
        return self._metadata.metadata

    @property
    def header_length(self) -> int:
        return self._metadata.header_length

    @property
    def size_bytes(self) -> int:
        return self._metadata.size_bytes
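A short usage sketch for `SafeTensorIndex` (file path illustrative): only the JSON header is parsed, so `TensorMeta` lookups cost no tensor reads at all:

```python
from comfy.safetensors_stream import SafeTensorIndex

index = SafeTensorIndex("/models/example.safetensors")  # path is illustrative
for key in index.keys():
    m = index.meta(key)
    # dtype/shape/offsets come straight from the parsed header
    print(key, m.dtype, m.shape, m.nbytes, m.data_offsets)
```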
class _SafeTensorFile:
    def __init__(self, filename: str, index: SafeTensorIndex):
        self.filename = filename
        self.index = index
        self._fd: Optional[int] = None
        self._gds_handle = None
        self._gds_reader = None
        self._nogds_reader = None
        self._refcount = 1

    def acquire(self) -> "_SafeTensorFile":
        self._refcount += 1
        return self

    def release(self):
        self._refcount -= 1
        if self._refcount <= 0:
            self.close()

    def close(self):
        if self._fd is not None:
            os.close(self._fd)
            self._fd = None
        self._gds_handle = None

    def _ensure_fd(self) -> int:
        if self._fd is None:
            self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
        return self._fd

    def _ensure_nogds_reader(self, use_cuda: bool):
        fst = _init_fastsafetensors_lib()
        if self._nogds_reader is None:
            self._nogds_reader = fst.cpp.nogds_file_reader(
                False, 16 * 1024, 16, use_cuda
            )
        return self._nogds_reader

    def _ensure_gds_reader(self, use_cuda: bool):
        fst = _init_fastsafetensors_lib()
        if self._gds_reader is None:
            self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
        return self._gds_reader

    def _ensure_gds_handle(self, use_cuda: bool):
        if self._gds_handle is None:
            fst = _init_fastsafetensors_lib()
            framework = fst.frameworks.get_framework_op("pytorch")
            o_direct = _get_gds_o_direct(framework)
            self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
        return self._gds_handle

    def read_tensor(
        self,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
        allow_gds: bool,
        pin_if_cpu: bool,
    ) -> torch.Tensor:
        fst = _init_fastsafetensors_lib()
        framework = fst.frameworks.get_framework_op("pytorch")
        device_is_cuda = device.type == "cuda"
        if device_is_cuda and allow_gds:
            _ensure_gds_ready(device)
            tensor = self._read_tensor_gds(
                fst, framework, meta, device, dtype
            )
            return tensor

        cpu_tensor = self._read_tensor_nogds(
            fst, framework, meta, torch.device("cpu"), dtype
        )
        if device_is_cuda:
            if pin_if_cpu:
                cpu_tensor = cpu_tensor.pin_memory()
            gpu_tensor = torch.empty_like(cpu_tensor, device=device)
            gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
            return gpu_tensor
        return cpu_tensor

    def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
        fst = _init_fastsafetensors_lib()
        align = fst.cpp.get_alignment_size()
        aligned_offset = (abs_start // align) * align
        head = abs_start - aligned_offset
        aligned_length = length + head
        tail = aligned_length % align
        if tail:
            aligned_length += align - tail
        return aligned_offset, aligned_length, head
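A worked example of `_aligned_range`, assuming `get_alignment_size()` returns 4096 (the actual value comes from fastsafetensors at runtime):

```python
align = 4096
abs_start, length = 10000, 5000             # tensor payload occupies bytes [10000, 15000)
aligned_offset = (abs_start // align) * align   # 8192: round start down to a block boundary
head = abs_start - aligned_offset               # 1808 bytes of pre-padding before the payload
aligned_length = length + head                  # 6808 bytes needed from aligned_offset
tail = aligned_length % align                   # 2712: partial trailing block
aligned_length += align - tail                  # 8192: read exactly two 4 KiB blocks
assert aligned_offset % align == 0
assert aligned_offset + aligned_length >= abs_start + length
```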
    def _read_tensor_nogds(
        self,
        fst,
        framework,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
    ) -> torch.Tensor:
        fd = self._ensure_fd()
        reader = self._ensure_nogds_reader(use_cuda=False)
        abs_start = self.index.header_length + meta.data_offsets[0]
        length = meta.data_offsets[1] - meta.data_offsets[0]
        aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
        ptr_align = framework.get_device_ptr_align()
        buffer_length = aligned_length + ptr_align
        buf_ptr = fst.cpp.cpu_malloc(buffer_length)
        gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
        ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
        req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
        if reader.wait_read(req) < 0:
            fst.cpp.cpu_free(buf_ptr)
            raise RuntimeError("nogds_file_reader read failed")
        owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr))
        tensor = _dlpack_tensor_from_buffer(
            fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
        )
        if dtype is not None and dtype != tensor.dtype:
            _validate_dtype_conversion(tensor.dtype, dtype)
            tensor = tensor.to(dtype=dtype)
        return tensor

    def _read_tensor_gds(
        self,
        fst,
        framework,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
    ) -> torch.Tensor:
        reader = self._ensure_gds_reader(use_cuda=True)
        handle = self._ensure_gds_handle(use_cuda=True)
        abs_start = self.index.header_length + meta.data_offsets[0]
        length = meta.data_offsets[1] - meta.data_offsets[0]
        aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
        ptr_align = framework.get_device_ptr_align()
        buffer_length = aligned_length + ptr_align
        fst_device = _fst_device_from_torch(fst, device)
        gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
        ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
        file_length = self.index.size_bytes
        req = reader.submit_read(
            handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
        )
        if reader.wait_read(req) < 0:
            framework.free_tensor_memory(gbuf, fst_device)
            raise RuntimeError("gds_file_reader read failed")
        owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
        tensor = _dlpack_tensor_from_buffer(
            fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
        )
        if dtype is not None and dtype != tensor.dtype:
            _validate_dtype_conversion(tensor.dtype, dtype)
            tensor = tensor.to(dtype=dtype)
        return tensor
def _fst_device_from_torch(fst, device: torch.device):
    if device.type == "cuda" and device.index is not None:
        return fst.st_types.Device.from_str(f"cuda:{device.index}")
    return fst.st_types.Device.from_str(device.type)


class _BufferOwner:
    def __init__(self, free_fn):
        self._free_fn = free_fn

    def __del__(self):
        try:
            self._free_fn()
        except Exception:
            pass


def _dlpack_tensor_from_buffer(
    fst,
    framework,
    ptr: int,
    meta: TensorMeta,
    device: torch.device,
    owner: Optional[_BufferOwner],
) -> torch.Tensor:
    disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
    dev = _fst_device_from_torch(fst, device)
    dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
    torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
    if disk_dtype != meta.fst_dtype:
        torch_tensor = torch_tensor.view(meta.dtype)
    if owner is not None:
        torch_tensor._comfy_disk_buffer_owner = owner
    return torch_tensor


def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
    if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
        raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")


def _get_gds_o_direct(framework) -> bool:
    cuda_ver = framework.get_cuda_ver()
    if cuda_ver and cuda_ver != "0.0":
        ver_parts = cuda_ver.split("-", 1)
        if len(ver_parts) == 2:
            cudavers = list(map(int, ver_parts[1].split(".")))
            if ver_parts[0] == "cuda":
                return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
        return True
    return True


def _ensure_gds_ready(device: torch.device):
    fst = _init_fastsafetensors_lib()
    if not fst.common.is_gpu_found():
        raise RuntimeError(
            "GPUDirect requested but GPU runtime library is missing. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
    if gds_supported < 0:
        raise RuntimeError(
            "GPUDirect requested but is_gds_supported() failed. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    if not fst.cpp.is_cufile_found():
        raise RuntimeError(
            "GPUDirect requested but libcufile is missing. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    if gds_supported == 0:
        raise RuntimeError(
            "GPUDirect requested but GDS is unsupported on this platform. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    _init_gds()
class StreamStateDict(collections.abc.MutableMapping):
    is_stream_state_dict = True

    def __init__(
        self,
        index: SafeTensorIndex,
        file: _SafeTensorFile,
        device: torch.device,
        allow_gds: bool = False,
    ):
        self._index = index
        self._file = file
        self._device = device
        self._allow_gds = allow_gds
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    @classmethod
    def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
        index = SafeTensorIndex(filename)
        file = _SafeTensorFile(filename, index)
        return cls(index, file, device, allow_gds=allow_gds)

    def close(self):
        if self._file is not None:
            self._file.release()
            self._file = None

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass

    def meta(self, key: str) -> TensorMeta:
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return TensorMeta(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
                data_offsets=(0, numel * t.element_size()),
                filename="<override>",
                fst_dtype=None,
                strides=tuple(t.stride()),
            )
        if key in self._deleted:
            raise KeyError(key)
        return self._index.meta(key)

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        if key in self._overrides:
            t = self._overrides[key]
            if device is not None and t.device != device:
                t = t.to(device=device)
            if dtype is not None and t.dtype != dtype:
                _validate_dtype_conversion(t.dtype, dtype)
                t = t.to(dtype=dtype)
            return t
        if key in self._deleted:
            raise KeyError(key)
        if device is None:
            device = self._device
        if allow_gds is None:
            allow_gds = self._allow_gds
        meta = self._index.meta(key)
        return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.get_tensor(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        if key in self._deleted:
            raise KeyError(key)
        if self._index.has(key):
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        for k in self._index.keys():
            if k in self._deleted:
                continue
            if k in self._overrides:
                continue
            yield k
        for k in self._overrides.keys():
            yield k

    def __len__(self) -> int:
        # Overrides may shadow index keys; count each visible key exactly once,
        # matching what __iter__ yields.
        overlap = sum(1 for k in self._overrides if self._index.has(k))
        return len(self._index.keys()) - len(self._deleted) - overlap + len(self._overrides)

    def __contains__(self, key: object) -> bool:
        if not isinstance(key, str):
            return False
        if key in self._deleted:
            return False
        if key in self._overrides:
            return True
        return self._index.has(key)

    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        if key in self._deleted:
            if default is _MISSING:
                raise KeyError(key)
            return default
        if self._index.has(key):
            # Read before tombstoning: get_tensor() raises for deleted keys.
            tensor = self.get_tensor(key)
            self._deleted.add(key)
            return tensor
        if default is _MISSING:
            raise KeyError(key)
        return default

    def copy(self) -> "StreamStateDict":
        new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
        new._overrides = dict(self._overrides)
        new._deleted = set(self._deleted)
        return new

    def metadata(self):
        return self._index.metadata()
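A minimal usage sketch for `StreamStateDict` (file path and key name are illustrative):

```python
import torch
from comfy.safetensors_stream import StreamStateDict

sd = StreamStateDict.from_file("/models/unet.safetensors", device=torch.device("cpu"))
meta = sd.meta("input_blocks.0.0.weight")   # header lookup only, no disk read
w = sd["input_blocks.0.0.weight"]           # reads exactly this tensor's bytes
sd.close()
```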
class _BaseViewStateDict(MutableMapping):
    is_stream_state_dict = True

    def __init__(self, base: MutableMapping, mutate_base: bool = False):
        self._base = base
        self._mutate_base = mutate_base
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return key

    def _iter_base_keys(self) -> Iterable[str]:
        return self._base.keys()

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        if key in self._overrides:
            t = self._overrides[key]
            if device is not None and t.device != device:
                t = t.to(device=device)
            if dtype is not None and t.dtype != dtype:
                _validate_dtype_conversion(t.dtype, dtype)
                t = t.to(dtype=dtype)
            return t
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if hasattr(self._base, "get_tensor"):
            return self._base.get_tensor(
                base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
            )
        t = self._base[base_key]
        if device is not None and t.device != device:
            t = t.to(device=device)
        if dtype is not None and t.dtype != dtype:
            _validate_dtype_conversion(t.dtype, dtype)
            t = t.to(dtype=dtype)
        return t

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return SimpleNamespace(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
            )
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if hasattr(self._base, "meta"):
            return self._base.meta(base_key)
        t = self._base[base_key]
        numel = t.numel()
        return SimpleNamespace(
            dtype=t.dtype,
            shape=tuple(t.shape),
            numel=numel,
            nbytes=numel * t.element_size(),
        )

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.get_tensor(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        base_key = self._resolve_base_key(key)
        if self._mutate_base and base_key is not None and base_key in self._base:
            self._base[base_key] = value
        else:
            self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if self._mutate_base and base_key in self._base:
            del self._base[base_key]
        else:
            self._deleted.add(key)

    def __iter__(self) -> Iterator[str]:
        # Skip base keys shadowed by overrides so each key is yielded once.
        for k in self._iter_base_keys():
            if k in self._deleted or k in self._overrides:
                continue
            yield k
        for k in self._overrides.keys():
            yield k

    def __len__(self) -> int:
        base_keys = set(self._iter_base_keys())
        overlap = sum(1 for k in self._overrides if k in base_keys)
        return len(base_keys) - len(self._deleted) + len(self._overrides) - overlap

    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            if default is _MISSING:
                raise KeyError(key)
            return default
        if self._mutate_base:
            try:
                return self._base.pop(base_key)
            except KeyError:
                if default is _MISSING:
                    raise
                return default
        # Read before tombstoning: get_tensor() raises for deleted keys.
        tensor = self.get_tensor(key)
        self._deleted.add(key)
        return tensor
class FilterViewStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._predicate = predicate

    def _resolve_base_key(self, key: str) -> Optional[str]:
        if self._predicate(key):
            return key
        return None

    def _iter_base_keys(self) -> Iterable[str]:
        for k in self._base.keys():
            if self._predicate(k):
                yield k


class PrefixViewStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._source_prefix = source_prefix
        self._target_prefix = target_prefix
        self._mapping: Dict[str, str] = {}
        self._reverse: Dict[str, str] = {}
        for k in base.keys():
            if not k.startswith(source_prefix):
                continue
            view_key = f"{target_prefix}{k[len(source_prefix):]}"
            self._mapping[k] = view_key
            self._reverse[view_key] = k

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._reverse.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._reverse.keys()


class RenameViewStateDict(_BaseViewStateDict):
    def __init__(
        self,
        base: MutableMapping,
        replace_prefix: Mapping[str, str],
        filter_keys: bool = False,
        mutate_base: bool = False,
    ):
        super().__init__(base, mutate_base=mutate_base)
        self._filter_keys = filter_keys
        self._replace = list(replace_prefix.items())
        self._mapping: Dict[str, str] = {}
        self._reverse: Dict[str, str] = {}
        for k in base.keys():
            view_key = self._replace_key(k)
            if view_key is None:
                continue
            self._mapping[k] = view_key
            self._reverse[view_key] = k

    def _replace_key(self, key: str) -> Optional[str]:
        for rp, dst in self._replace:
            if key.startswith(rp):
                return f"{dst}{key[len(rp):]}"
        if self._filter_keys:
            return None
        return key

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._reverse.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._reverse.keys()
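Sketch of how the view classes compose without copying tensor data (key names illustrative; `sd` is a `StreamStateDict` or any mapping):

```python
view = PrefixViewStateDict(sd, source_prefix="model.diffusion_model.", target_prefix="")
weights_only = FilterViewStateDict(view, lambda k: k.endswith(".weight"))
keys = list(weights_only.keys())             # pure key bookkeeping, no I/O
t = weights_only["input_blocks.0.0.weight"]  # first and only disk read
```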
class MergedStateDict(MutableMapping):
    is_stream_state_dict = True

    def __init__(self, *mappings: MutableMapping):
        self._mappings = list(mappings)
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    def __getitem__(self, key: str) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides[key]
        if key in self._deleted:
            raise KeyError(key)
        for mapping in reversed(self._mappings):
            if key in mapping:
                if hasattr(mapping, "get_tensor"):
                    return mapping.get_tensor(key)
                return mapping[key]
        raise KeyError(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        if key in self._deleted:
            raise KeyError(key)
        if any(key in mapping for mapping in self._mappings):
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        seen = set()
        for mapping in self._mappings:
            for key in mapping.keys():
                if key in self._deleted or key in seen:
                    continue
                seen.add(key)
                yield key
        for key in self._overrides.keys():
            if key not in seen:
                yield key

    def __len__(self) -> int:
        return len(list(self.__iter__()))

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return SimpleNamespace(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
            )
        if key in self._deleted:
            raise KeyError(key)
        for mapping in reversed(self._mappings):
            if key in mapping:
                if hasattr(mapping, "meta"):
                    return mapping.meta(key)
                t = mapping[key]
                numel = t.numel()
                return SimpleNamespace(
                    dtype=t.dtype,
                    shape=tuple(t.shape),
                    numel=numel,
                    nbytes=numel * t.element_size(),
                )
        raise KeyError(key)


class MappedStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._base_to_view = dict(key_map)
        self._view_to_base = {v: k for k, v in key_map.items()}

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._view_to_base.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._view_to_base.keys()
129  comfy/sd.py
@@ -25,6 +25,7 @@ import math
 import os

 import comfy.utils
+import comfy.safetensors_stream

 from . import clip_vision
 from . import gligen
@@ -288,7 +289,7 @@ class CLIP:

     def load_sd(self, sd, full_model=False):
         if full_model:
-            return self.cond_stage_model.load_state_dict(sd, strict=False)
+            return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
         else:
             return self.cond_stage_model.load_sd(sd)
@@ -346,7 +347,7 @@ class VAE:
                                                         encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
                                                         decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
         elif "taesd_decoder.1.weight" in sd:
-            self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
+            self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1]
             self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
         elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
             self.first_stage_model = StageA()
@@ -361,25 +362,19 @@ class VAE:
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
-            new_sd = {}
-            for k in sd:
-                new_sd["encoder.{}".format(k)] = sd[k]
-            sd = new_sd
+            sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
         elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
             self.first_stage_model = StageC_coder()
             self.latent_channels = 16
-            new_sd = {}
-            for k in sd:
-                new_sd["previewer.{}".format(k)] = sd[k]
-            sd = new_sd
+            sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."})
         elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
         elif "decoder.conv_in.weight" in sd:
-            if sd['decoder.conv_in.weight'].shape[1] == 64:
+            if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64:
                 ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
-                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
                 self.downscale_ratio = 32
                 self.upscale_ratio = 32
                 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -389,9 +384,9 @@ class VAE:

                 self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
                 self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
-            elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
+            elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5:
                 ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
-                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
                 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
                 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
                 self.upscale_index_formula = (4, 16, 16)
@@ -414,7 +409,7 @@ class VAE:
                 self.downscale_ratio = 4
                 self.upscale_ratio = 4

-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+            self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
             if 'decoder.post_quant_conv.weight' in sd:
                 sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})

@@ -427,7 +422,7 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0

             if 'post_quant_conv.weight' in sd:
-                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
             else:
                 self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                             encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@@ -462,11 +457,11 @@ class VAE:
             self.downscale_index_formula = (6, 8, 8)
             self.working_dtypes = [torch.float16, torch.float32]
         elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
-            tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
+            tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight")
             version = 0
-            if tensor_conv1.shape[0] == 512:
+            if tensor_conv1_shape[0] == 512:
                 version = 0
-            elif tensor_conv1.shape[0] == 1024:
+            elif tensor_conv1_shape[0] == 1024:
                 version = 1
             if "encoder.down_blocks.1.conv.conv.bias" in sd:
                 version = 2
@@ -483,9 +478,9 @@ class VAE:
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
             self.downscale_index_formula = (8, 32, 32)
             self.working_dtypes = [torch.bfloat16, torch.float32]
-        elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+        elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
-            ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+            ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
             self.latent_channels = 32
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
             self.upscale_index_formula = (4, 16, 16)
@@ -509,8 +504,8 @@ class VAE:
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
             self.downscale_index_formula = (4, 8, 8)
             self.latent_dim = 3
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
-            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+            self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
+            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
             #This is likely to significantly over-estimate with single image or low frame counts as the
             #implementation is able to completely skip caching. Rework if used as an image only VAE
             self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
@@ -543,14 +538,14 @@ class VAE:
                 self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
                 self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
             else: # Wan 2.1 VAE
-                dim = sd["decoder.head.0.gamma"].shape[0]
+                dim = sd_shape(sd, "decoder.head.0.gamma")[0]
                 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
                 self.upscale_index_formula = (4, 8, 8)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
                 self.downscale_index_formula = (4, 8, 8)
                 self.latent_dim = 3
                 self.latent_channels = 16
-                self.output_channels = sd["encoder.conv1.weight"].shape[1]
+                self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1]
                 self.pad_channel_value = 1.0
                 ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
                 self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
@@ -626,7 +621,7 @@ class VAE:
             self.working_dtypes = [torch.float32]
             self.crop_input = False
         elif "decoder.22.bias" in sd: # taehv, taew and lighttae
-            self.latent_channels = sd["decoder.1.weight"].shape[1]
+            self.latent_channels = sd_shape(sd, "decoder.1.weight")[1]
             self.latent_dim = 3
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
             self.upscale_index_formula = (4, 16, 16)
@@ -637,12 +632,12 @@ class VAE:
                 self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
                 self.process_output = lambda image: image
                 self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
-            elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
+            elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15
                 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
                 self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
                 self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
             else:
-                if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
+                if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
                     latent_format=comfy.latent_formats.HunyuanVideo
                 else:
                     latent_format=None # lighttaew2_1 doesn't need scaling
@@ -662,7 +657,7 @@ class VAE:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()

-        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
+        m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
         if len(m) > 0:
             logging.warning("Missing VAE keys {}".format(m))
@@ -983,9 +978,12 @@ def load_style_model(ckpt_path):
         model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
-    model.load_state_dict(model_data)
+    comfy.utils.load_state_dict(model, model_data, strict=True)
     return StyleModel(model)

+def sd_shape(state_dict, key):
+    return comfy.utils.state_dict_meta(state_dict, key).shape
+
 class CLIPType(Enum):
     STABLE_DIFFUSION = 1
     STABLE_CASCADE = 2
@@ -1055,16 +1053,16 @@ def detect_te_model(sd):
     if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
         return TEModel.JINA_CLIP_2
     if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
-        weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
-        if weight.shape[-1] == 4096:
+        weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight")
+        if weight_shape[-1] == 4096:
             return TEModel.T5_XXL
-        elif weight.shape[-1] == 2048:
+        elif weight_shape[-1] == 2048:
             return TEModel.T5_XL
     if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
         return TEModel.T5_XXL_OLD
     if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
-        weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
-        if weight.shape[0] == 384:
+        weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight')
+        if weight_shape[0] == 384:
             return TEModel.BYT5_SMALL_GLYPH
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
@@ -1074,19 +1072,19 @@ def detect_te_model(sd):
             return TEModel.GEMMA_3_4B
         return TEModel.GEMMA_2_2B
     if 'model.layers.0.self_attn.k_proj.bias' in sd:
-        weight = sd['model.layers.0.self_attn.k_proj.bias']
-        if weight.shape[0] == 256:
+        weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias')
+        if weight_shape[0] == 256:
            return TEModel.QWEN25_3B
-        if weight.shape[0] == 512:
+        if weight_shape[0] == 512:
            return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
-        weight = sd['model.layers.0.post_attention_layernorm.weight']
+        weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight')
         if 'model.layers.0.self_attn.q_norm.weight' in sd:
-            if weight.shape[0] == 2560:
+            if weight_shape[0] == 2560:
                 return TEModel.QWEN3_4B
-            elif weight.shape[0] == 2048:
+            elif weight_shape[0] == 2048:
                 return TEModel.QWEN3_2B
-        if weight.shape[0] == 5120:
+        if weight_shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
                 return TEModel.MISTRAL3_24B
             else:
@@ -1415,19 +1413,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             scaled_fp8_list.append(k[:-len("scaled_fp8")])

     if len(scaled_fp8_list) > 0:
-        out_sd = {}
-        for k in sd:
-            skip = False
-            for pref in scaled_fp8_list:
-                skip = skip or k.startswith(pref)
-            if not skip:
-                out_sd[k] = sd[k]
-
-        for pref in scaled_fp8_list:
-            quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
-            for k in quant_sd:
-                out_sd[k] = quant_sd[k]
-        sd = out_sd
+        if comfy.utils.is_stream_state_dict(sd):
+            def _keep_key(k, prefixes=tuple(scaled_fp8_list)):
+                return not any(k.startswith(pref) for pref in prefixes)
+            out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
+            merged = out_sd
+            for pref in scaled_fp8_list:
+                quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd)
+            sd = merged
+        else:
+            out_sd = {}
+            for k in sd:
+                skip = False
+                for pref in scaled_fp8_list:
+                    skip = skip or k.startswith(pref)
+                if not skip:
+                    out_sd[k] = sd[k]
+
+            for pref in scaled_fp8_list:
+                quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                for k in quant_sd:
+                    out_sd[k] = quant_sd[k]
+            sd = out_sd

     clip_target = model_config.clip_target(state_dict=sd)
     if clip_target is not None:
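The stream branch above layers views instead of copying: `FilterViewStateDict` hides the legacy `scaled_fp8` keys, and each `MergedStateDict` wrap gives the converted quant keys precedence, since `MergedStateDict.__getitem__` scans `reversed(self._mappings)`. In sketch form (names taken from the hunk above):

```python
base = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
merged = comfy.safetensors_stream.MergedStateDict(base, quant_sd)
# merged["layer.input_scale"] resolves from quant_sd first, then falls back to base
```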
@@ -1505,12 +1513,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):

         diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

-        new_sd = {}
-        for k in diffusers_keys:
-            if k in sd:
-                new_sd[diffusers_keys[k]] = sd.pop(k)
-            else:
-                logging.warning("{} {}".format(diffusers_keys[k], k))
+        if comfy.utils.is_stream_state_dict(sd):
+            new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys)
+        else:
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    logging.warning("{} {}".format(diffusers_keys[k], k))

         offload_device = model_management.unet_offload_device()
         unet_weight_dtype = list(model_config.supported_inference_dtypes)
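For the streaming path above, `MappedStateDict` is a pure rename view, so the diffusers-to-internal key translation costs no tensor reads. A sketch (key names illustrative):

```python
key_map = {"down_blocks.0.attentions.0.proj_in.weight": "input_blocks.1.1.proj_in.weight"}
new_sd = comfy.safetensors_stream.MappedStateDict(sd, key_map)
assert "input_blocks.1.1.proj_in.weight" in new_sd  # still nothing read from disk
```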
comfy/sd1_clip.py
@@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         return self(tokens)

     def load_sd(self, sd):
-        return self.transformer.load_state_dict(sd, strict=False)
+        return comfy.utils.load_state_dict(self.transformer, sd, strict=False)

 def parse_parentheses(string):
     result = []
@@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     try:
         if embed_path.lower().endswith(".safetensors"):
-            import safetensors.torch
-            embed = safetensors.torch.load_file(embed_path, device="cpu")
+            embed = comfy.utils.load_torch_file(embed_path, safe_load=True)
         else:
             try:
                 embed = torch.load(embed_path, weights_only=True, map_location="cpu")
comfy/taesd/taesd.py
@@ -56,9 +56,9 @@ class TAESD(nn.Module):
         self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
         self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
         if encoder_path is not None:
-            self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
+            comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
         if decoder_path is not None:
-            self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
+            comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)

     @staticmethod
     def scale_latents(x):
@@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
         if len(sdo) == 0:
             sdo = sd

-        return self.load_state_dict(sdo, strict=False)
+        return comfy.utils.load_state_dict(self, sdo, strict=False)


 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
195  comfy/utils.py
@@ -26,10 +26,13 @@ import numpy as np
 from PIL import Image
 import logging
 import itertools
+from types import SimpleNamespace
 from torch.nn.functional import interpolate
 from einops import rearrange
 from comfy.cli_args import args
 import json
+from . import safetensors_stream
+import comfy.disk_weights

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
-            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
-                sd = {}
-                for k in f.keys():
-                    tensor = f.get_tensor(k)
-                    if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
-                        tensor = tensor.to(device=device, copy=True)
-                    sd[k] = tensor
-                if return_metadata:
-                    metadata = f.metadata()
+            sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device)
+            if return_metadata:
+                metadata = sd.metadata()
         except Exception as e:
             if len(e.args) > 0:
                 message = e.args[0]
@@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
         if k.startswith(prefix):
-            w = sd[k]
-            params += w.nelement()
+            meta = state_dict_meta(sd, k)
+            params += meta.numel
     return params

 def weight_dtype(sd, prefix=""):
     dtypes = {}
     for k in sd.keys():
         if k.startswith(prefix):
-            w = sd[k]
-            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
+            meta = state_dict_meta(sd, k)
+            dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel

     if len(dtypes) == 0:
         return None
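With the rewrite above, both helpers cost O(header) instead of materializing every tensor. A usage sketch (path illustrative):

```python
import comfy.utils

sd = comfy.utils.load_torch_file("/models/big.safetensors", safe_load=True)
n_params = comfy.utils.calculate_parameters(sd, prefix="model.diffusion_model.")
dominant_dtype = comfy.utils.weight_dtype(sd)  # histogram built from TensorMeta only
```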
@@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace):
     return state_dict

 def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if is_stream_state_dict(state_dict):
+        return safetensors_stream.RenameViewStateDict(
+            state_dict,
+            replace_prefix,
+            filter_keys=filter_keys,
+            mutate_base=not filter_keys,
+        )
     if filter_keys:
         out = {}
     else:
@@ -145,6 +149,77 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
     return out

+
+def is_stream_state_dict(state_dict) -> bool:
+    return getattr(state_dict, "is_stream_state_dict", False)
+
+
+def state_dict_meta(state_dict, key):
+    if hasattr(state_dict, "meta"):
+        return state_dict.meta(key)
+    w = state_dict[key]
+    numel = w.numel()
+    return SimpleNamespace(
+        dtype=w.dtype,
+        shape=tuple(w.shape),
+        numel=numel,
+        nbytes=numel * w.element_size(),
+    )
+
+
+def load_state_dict(model, state_dict, strict=False, assign=False):
+    if is_stream_state_dict(state_dict):
+        comfy.disk_weights.register_module_weights(model, state_dict)
+        comfy.disk_weights.attach_disk_weight_hooks(model)
+        missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
+        return missing, unexpected
+    return model.load_state_dict(state_dict, strict=strict)
+
+
+def stream_load_state_dict(model, state_dict, strict=False, assign=False):
+    if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
+        state_dict = state_dict.copy()
+    missing_keys = []
+    unexpected_keys = []
+    error_msgs = []
+    metadata = getattr(state_dict, "_metadata", None)
+
+    def load(module, local_state_dict, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        if assign:
+            local_metadata["assign_to_params_buffers"] = assign
+        module._load_from_state_dict(
+            local_state_dict,
+            prefix,
+            local_metadata,
+            True,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+        for name, child in module._modules.items():
+            if child is not None:
+                child_prefix = f"{prefix}{name}."
+                child_state_dict = safetensors_stream.FilterViewStateDict(
+                    local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
+                )
+                load(child, child_state_dict, child_prefix)
+        incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+        for hook in module._load_state_dict_post_hooks.values():
+            out = hook(module, incompatible)
+            if out is not None:
+                raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+
+    load(model, state_dict)
+    if strict:
+        if len(unexpected_keys) > 0:
+            error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
+    if len(error_msgs) > 0:
+        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
+    return missing_keys, unexpected_keys
+
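Usage sketch for the streaming loader above (`model` construction elided): each submodule receives a `FilterViewStateDict` restricted to its own prefix, so tensors are pulled from disk one module at a time rather than all at once:

```python
sd = comfy.utils.load_torch_file("/models/vae.safetensors", safe_load=True)
missing, unexpected = comfy.utils.load_state_dict(model, sd, strict=False)
# Peak RAM is roughly one submodule's weights, not the full checkpoint.
```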
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -1217,46 +1292,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     scaled_fp8_key = "{}scaled_fp8".format(model_prefix)

     if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict[scaled_fp8_key]
-        scaled_fp8_dtype = scaled_fp8_weight.dtype
+        if is_stream_state_dict(state_dict):
+            scaled_meta = state_dict_meta(state_dict, scaled_fp8_key)
+            scaled_fp8_dtype = scaled_meta.dtype
+            scaled_fp8_weight_nelements = scaled_meta.numel
+        else:
+            scaled_fp8_weight = state_dict[scaled_fp8_key]
+            scaled_fp8_dtype = scaled_fp8_weight.dtype
+            scaled_fp8_weight_nelements = scaled_fp8_weight.nelement()
         if scaled_fp8_dtype == torch.float32:
             scaled_fp8_dtype = torch.float8_e4m3fn

-        if scaled_fp8_weight.nelement() == 2:
+        if scaled_fp8_weight_nelements == 2:
             full_precision_matrix_mult = True
         else:
             full_precision_matrix_mult = False

-        out_sd = {}
         layers = {}
-        for k in list(state_dict.keys()):
-            if k == scaled_fp8_key:
-                continue
-            if not k.startswith(model_prefix):
-                out_sd[k] = state_dict[k]
-                continue
-            k_out = k
-            w = state_dict.pop(k)
-            layer = None
-            if k_out.endswith(".scale_weight"):
-                layer = k_out[:-len(".scale_weight")]
-                k_out = "{}.weight_scale".format(layer)
-
-            if layer is not None:
-                layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
-                if full_precision_matrix_mult:
-                    layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
-                layers[layer] = layer_conf
-
-            if k_out.endswith(".scale_input"):
-                layer = k_out[:-len(".scale_input")]
-                k_out = "{}.input_scale".format(layer)
-                if w.item() == 1.0:
-                    continue
-
-            out_sd[k_out] = w
-
-        state_dict = out_sd
+        if is_stream_state_dict(state_dict):
+            key_map = {}
+            for k in list(state_dict.keys()):
+                if k == scaled_fp8_key:
+                    continue
+                if not k.startswith(model_prefix):
+                    key_map[k] = k
+                    continue
+                k_out = k
+                layer = None
+                if k_out.endswith(".scale_weight"):
+                    layer = k_out[:-len(".scale_weight")]
+                    k_out = "{}.weight_scale".format(layer)
+
+                if layer is not None:
+                    layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+                    if full_precision_matrix_mult:
+                        layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+                    layers[layer] = layer_conf
+
+                if k_out.endswith(".scale_input"):
+                    layer = k_out[:-len(".scale_input")]
+                    k_out = "{}.input_scale".format(layer)
+                    scale_val = state_dict[k]
+                    if scale_val.item() == 1.0:
+                        continue
+
+                key_map[k] = k_out
+            state_dict = safetensors_stream.MappedStateDict(state_dict, key_map)
+        else:
+            out_sd = {}
+            for k in list(state_dict.keys()):
+                if k == scaled_fp8_key:
+                    continue
+                if not k.startswith(model_prefix):
+                    out_sd[k] = state_dict[k]
+                    continue
+                k_out = k
+                w = state_dict.pop(k)
+                layer = None
+                if k_out.endswith(".scale_weight"):
+                    layer = k_out[:-len(".scale_weight")]
+                    k_out = "{}.weight_scale".format(layer)
+
+                if layer is not None:
+                    layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+                    if full_precision_matrix_mult:
+                        layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+                    layers[layer] = layer_conf
+
+                if k_out.endswith(".scale_input"):
+                    layer = k_out[:-len(".scale_input")]
+                    k_out = "{}.input_scale".format(layer)
+                    if w.item() == 1.0:
+                        continue
+
+                out_sd[k_out] = w
+
+            state_dict = out_sd
         quant_metadata = {"layers": layers}
     else:
         quant_metadata = json.loads(metadata["_quantization_metadata"])
3  nodes.py
@@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence
 from PIL.PngImagePlugin import PngInfo

 import numpy as np
-import safetensors.torch

 sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))

@@ -525,7 +524,7 @@ class LoadLatent:

     def load(self, latent):
         latent_path = folder_paths.get_annotated_filepath(latent)
-        latent = safetensors.torch.load_file(latent_path, device="cpu")
+        latent = comfy.utils.load_torch_file(latent_path, safe_load=True)
         multiplier = 1.0
         if "latent_format_version_0" not in latent:
             multiplier = 1.0 / 0.18215
@@ -11,6 +11,7 @@ transformers>=4.50.3
 tokenizers>=0.13.3
 sentencepiece
 safetensors>=0.4.2
+fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip
 aiohttp>=3.11.8
 yarl>=1.18.0
 pyyaml
146  tests-unit/utils/safetensors_stream_test.py  Normal file
@@ -0,0 +1,146 @@
import os
|
||||
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def _write_safetensors(tmp_path, tensors):
|
||||
import safetensors.torch
|
||||
path = os.path.join(tmp_path, "test.safetensors")
|
||||
safetensors.torch.save_file(tensors, path)
|
||||
return path
|
||||
|
||||
|
||||
def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
meta = sd.meta("a")
|
||||
assert meta.shape == (2, 3)
|
||||
assert meta.dtype == torch.float32
|
||||
assert meta.numel == 6
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
_ = sd["a"]
|
||||
assert len(calls) == 1
|
||||
assert calls[0].shape == (2, 3)
|
||||
|
||||
|
||||
def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
calls = []
|
||||
|
||||
original = sd._file.read_tensor
|
||||
|
||||
def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
|
||||
calls.append(meta)
|
||||
return original(meta, device, dtype, allow_gds, pin_if_cpu)
|
||||
|
||||
monkeypatch.setattr(sd._file, "read_tensor", wrapped)
|
||||
view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
|
||||
_ = list(view.keys())
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_stream_load_rss_small(tmp_path):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
psutil = pytest.importorskip("psutil")
|
||||
process = psutil.Process()
|
||||
size_elems = 4_000_000 # ~16MB float32
|
||||
tensor = torch.zeros((size_elems,), dtype=torch.float32)
|
||||
path = _write_safetensors(tmp_path, {"big": tensor})
|
||||
rss_before = process.memory_info().rss
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
rss_after = process.memory_info().rss
|
||||
expected_size = tensor.numel() * tensor.element_size()
|
||||
assert (rss_after - rss_before) < expected_size
|
||||
_ = sd.meta("big")
|
||||
|
||||
|
||||
def test_gds_path_errors_without_support(tmp_path, monkeypatch):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
device = torch.device("cuda")
|
||||
|
||||
if importlib.util.find_spec("fastsafetensors") is None:
|
||||
fst = None
|
||||
else:
|
||||
fst = importlib.import_module("fastsafetensors")
|
||||
|
||||
gds_available = False
|
||||
if fst is not None and torch.cuda.is_available():
|
||||
gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
|
||||
gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1
|
||||
|
||||
if not gds_available:
|
||||
with pytest.raises(RuntimeError, match="GPUDirect requested"):
|
||||
sd.get_tensor("a", device=device, allow_gds=True)
|
||||
else:
|
||||
def fail_nogds(*args, **kwargs):
|
||||
raise AssertionError("nogds path used during GDS request")
|
||||
|
||||
monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
|
||||
t = sd.get_tensor("a", device=device, allow_gds=True)
|
||||
assert t.device.type == "cuda"
|
||||
|
||||
|
||||
def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
|
||||
if torch is None:
|
||||
pytest.skip("torch not installed")
|
||||
import comfy.utils
|
||||
import comfy.disk_weights
|
||||
|
||||
prev_cache = comfy.disk_weights.CACHE.max_bytes
|
||||
prev_gds = comfy.disk_weights.ALLOW_GDS
|
||||
prev_pin = comfy.disk_weights.PIN_IF_CPU
|
||||
prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
|
||||
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
|
||||
|
||||
try:
|
||||
path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = torch.nn.Linear(4, 4, bias=True)
|
||||
comfy.utils.load_state_dict(model, sd, strict=False)
|
||||
assert model.weight.device.type == "cpu"
|
||||
assert model.weight.device.type != "meta"
|
||||
finally:
|
||||
comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
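To exercise just this module (repo layout as above), run: pytest tests-unit/utils/safetensors_stream_test.py -v. The RSS test additionally requires psutil and is skipped when that package is missing, as is the whole module when torch is not installed.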