"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""

from __future__ import annotations

import collections
import ctypes
import importlib
import importlib.util
import os
import threading
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Tuple

import torch
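
# Streaming access to .safetensors files on top of fastsafetensors: tensors
# are read on demand, either straight to the GPU via GPUDirect Storage (GDS)
# or through a chunked host bounce buffer, and exposed behind MutableMapping
# "state dict" views.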

_FST_MODULE = None
_FST_LOCK = threading.Lock()
_FST_LOADED = False
_GDS_INITIALIZED = False
_MISSING = object()
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
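

# fastsafetensors is imported lazily so the dependency stays optional; the
# helpers below import it under _FST_LOCK and initialize the native library
# and (when requested) GDS at most once per process.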
def _require_fastsafetensors():
    global _FST_MODULE
    with _FST_LOCK:
        if _FST_MODULE is None:
            if importlib.util.find_spec("fastsafetensors") is None:
                raise ImportError(
                    "fastsafetensors is required for safetensors streaming. "
                    "Install it with: pip install 'fastsafetensors @ https://github.com/"
                    "foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
                )
            _FST_MODULE = importlib.import_module("fastsafetensors")
        return _FST_MODULE


def _init_fastsafetensors_lib():
    global _FST_LOADED
    fst = _require_fastsafetensors()
    if not _FST_LOADED:
        fst.cpp.load_library_functions()
        _FST_LOADED = True
    return fst


def _init_gds():
    global _GDS_INITIALIZED
    fst = _init_fastsafetensors_lib()
    if not _GDS_INITIALIZED:
        if fst.cpp.init_gds() != 0:
            raise RuntimeError("fastsafetensors init_gds() failed")
        _GDS_INITIALIZED = True
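

# Static description of a single tensor in a .safetensors file: enough to
# allocate a destination buffer and locate the payload bytes without reading
# them.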
@dataclass(frozen=True)
class TensorMeta:
    dtype: torch.dtype
    shape: Tuple[int, ...]
    numel: int
    nbytes: int
    data_offsets: Tuple[int, int]
    filename: str
    fst_dtype: object
    strides: Tuple[int, ...]
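

# Parses a .safetensors header once and caches per-tensor metadata; tensor
# payloads are not touched here.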
class SafeTensorIndex:
    def __init__(self, filename: str):
        fst = _init_fastsafetensors_lib()
        framework = fst.frameworks.get_framework_op("pytorch")
        metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
        self._filename = filename
        self._metadata = metadata
        self._framework = framework
        from fastsafetensors.frameworks import _torch as fst_torch
        self._dtype_map = fst_torch.dtype_convert
        self._tensor_meta: Dict[str, TensorMeta] = {}
        for key, frame in metadata.tensors.items():
            torch_dtype = self._dtype_map.get(frame.dtype, None)
            if torch_dtype is None:
                raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
            numel = 1
            for s in frame.shape:
                numel *= s
            nbytes = numel * framework.get_dtype_size(frame.dtype)
            self._tensor_meta[key] = TensorMeta(
                dtype=torch_dtype,
                shape=tuple(frame.shape),
                numel=numel,
                nbytes=nbytes,
                data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
                filename=filename,
                fst_dtype=frame.dtype,
                strides=tuple(frame.strides),
            )

    def keys(self) -> Iterable[str]:
        return self._tensor_meta.keys()

    def has(self, key: str) -> bool:
        return key in self._tensor_meta

    def meta(self, key: str) -> TensorMeta:
        return self._tensor_meta[key]

    def metadata(self):
        return self._metadata.metadata

    @property
    def header_length(self) -> int:
        return self._metadata.header_length

    @property
    def size_bytes(self) -> int:
        return self._metadata.size_bytes
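

# Refcounted wrapper around one open .safetensors file. It owns the file
# descriptor and the lazily created fastsafetensors readers/handles that the
# read paths below share.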
class _SafeTensorFile:
    def __init__(self, filename: str, index: SafeTensorIndex):
        self.filename = filename
        self.index = index
        self._fd: Optional[int] = None
        self._gds_handle = None
        self._gds_reader = None
        self._nogds_reader = None
        self._refcount = 1

    def acquire(self) -> "_SafeTensorFile":
        self._refcount += 1
        return self

    def release(self):
        self._refcount -= 1
        if self._refcount <= 0:
            self.close()

    def close(self):
        if self._fd is not None:
            os.close(self._fd)
            self._fd = None
        self._gds_handle = None

    def _ensure_fd(self) -> int:
        if self._fd is None:
            self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
        return self._fd

    def _ensure_nogds_reader(self, use_cuda: bool):
        fst = _init_fastsafetensors_lib()
        if self._nogds_reader is None:
            self._nogds_reader = fst.cpp.nogds_file_reader(
                False, 16 * 1024, 16, use_cuda
            )
        return self._nogds_reader

    def _ensure_gds_reader(self, use_cuda: bool):
        fst = _init_fastsafetensors_lib()
        if self._gds_reader is None:
            self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
        return self._gds_reader

    def _ensure_gds_handle(self, use_cuda: bool):
        if self._gds_handle is None:
            fst = _init_fastsafetensors_lib()
            framework = fst.frameworks.get_framework_op("pytorch")
            o_direct = _get_gds_o_direct(framework)
            self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
        return self._gds_handle
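
    # read_tensor picks between two paths: a direct disk->GPU read via GDS
    # when the destination is CUDA and GDS is allowed, or a bounce-buffer read
    # into host memory (optionally pinned) followed by a host-to-device copy
    # that is non-blocking only when the source is pinned.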
    def read_tensor(
        self,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
        allow_gds: bool,
        pin_if_cpu: bool,
    ) -> torch.Tensor:
        fst = _init_fastsafetensors_lib()
        framework = fst.frameworks.get_framework_op("pytorch")
        device_is_cuda = device.type == "cuda"
        if device_is_cuda and allow_gds:
            _ensure_gds_ready(device)
            tensor = self._read_tensor_gds(
                fst, framework, meta, device, dtype
            )
            return tensor

        cpu_tensor = self._read_tensor_nogds(
            fst, framework, meta, torch.device("cpu"), dtype
        )
        if device_is_cuda:
            if pin_if_cpu:
                cpu_tensor = cpu_tensor.pin_memory()
            gpu_tensor = torch.empty_like(cpu_tensor, device=device)
            gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
            return gpu_tensor
        return cpu_tensor
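
    # _aligned_range widens [abs_start, abs_start + length) to the alignment
    # reported by fst.cpp.get_alignment_size(). Illustrative numbers, assuming
    # a 4096-byte alignment: abs_start=5000, length=100 gives
    # aligned_offset=4096 and head=904; the padded length 1004 rounds up to
    # 4096, so the result is (4096, 4096, 904).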
    def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
        fst = _init_fastsafetensors_lib()
        align = fst.cpp.get_alignment_size()
        aligned_offset = (abs_start // align) * align
        head = abs_start - aligned_offset
        aligned_length = length + head
        tail = aligned_length % align
        if tail:
            aligned_length += align - tail
        return aligned_offset, aligned_length, head
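
    # Bounce-buffer path: the payload is read in aligned chunks (64 MiB by
    # default, overridable via COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES) into a
    # reusable host buffer, then memmoved into the destination tensor's
    # storage.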
    def _read_tensor_nogds(
        self,
        fst,
        framework,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
    ) -> torch.Tensor:
        fd = self._ensure_fd()
        reader = self._ensure_nogds_reader(use_cuda=False)
        abs_start = self.index.header_length + meta.data_offsets[0]
        length = meta.data_offsets[1] - meta.data_offsets[0]
        chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
        chunk_bytes = max(1, chunk_bytes)
        ptr_align = framework.get_device_ptr_align()
        dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
        buffer_length = 0
        buf_ptr = None
        gbuf = None
        try:
            chunk_offset = 0
            while chunk_offset < length:
                chunk_len = min(length - chunk_offset, chunk_bytes)
                aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
                needed = aligned_length + ptr_align
                # Grow the bounce buffer only when the next chunk needs more room.
                if buf_ptr is None or needed > buffer_length:
                    if buf_ptr is not None:
                        fst.cpp.cpu_free(buf_ptr)
                    buffer_length = needed
                    buf_ptr = fst.cpp.cpu_malloc(buffer_length)
                    gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
                ptr_off = (-(gbuf.get_base_address() + head)) % ptr_align
                req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
                if reader.wait_read(req) < 0:
                    raise RuntimeError("nogds_file_reader read failed")
                src_ptr = gbuf.get_base_address() + ptr_off + head
                dest_ptr = dest_tensor.data_ptr() + chunk_offset
                ctypes.memmove(dest_ptr, src_ptr, chunk_len)
                chunk_offset += chunk_len
        finally:
            if buf_ptr is not None:
                fst.cpp.cpu_free(buf_ptr)
        if dtype is not None and dtype != dest_tensor.dtype:
            _validate_dtype_conversion(dest_tensor.dtype, dtype)
            dest_tensor = dest_tensor.to(dtype=dtype)
        return dest_tensor
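
    # GDS path: read straight into device memory and wrap the buffer as a
    # torch tensor via DLPack; the _BufferOwner attached to the tensor frees
    # the allocation when the tensor is garbage collected.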
    def _read_tensor_gds(
        self,
        fst,
        framework,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
    ) -> torch.Tensor:
        reader = self._ensure_gds_reader(use_cuda=True)
        handle = self._ensure_gds_handle(use_cuda=True)
        abs_start = self.index.header_length + meta.data_offsets[0]
        length = meta.data_offsets[1] - meta.data_offsets[0]
        aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
        ptr_align = framework.get_device_ptr_align()
        buffer_length = aligned_length + ptr_align
        fst_device = _fst_device_from_torch(fst, device)
        gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
        ptr_off = (-(gbuf.get_base_address() + head)) % ptr_align
        file_length = self.index.size_bytes
        req = reader.submit_read(
            handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
        )
        if reader.wait_read(req) < 0:
            framework.free_tensor_memory(gbuf, fst_device)
            raise RuntimeError("gds_file_reader read failed")
        owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
        tensor = _dlpack_tensor_from_buffer(
            fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
        )
        if dtype is not None and dtype != tensor.dtype:
            _validate_dtype_conversion(tensor.dtype, dtype)
            tensor = tensor.to(dtype=dtype)
        return tensor


def _fst_device_from_torch(fst, device: torch.device):
    if device.type == "cuda" and device.index is not None:
        return fst.st_types.Device.from_str(f"cuda:{device.index}")
    return fst.st_types.Device.from_str(device.type)


class _BufferOwner:
    def __init__(self, free_fn):
        self._free_fn = free_fn

    def __del__(self):
        try:
            self._free_fn()
        except Exception:
            pass


def _dlpack_tensor_from_buffer(
    fst,
    framework,
    ptr: int,
    meta: TensorMeta,
    device: torch.device,
    owner: Optional[_BufferOwner],
) -> torch.Tensor:
    disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
    dev = _fst_device_from_torch(fst, device)
    dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
    torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
    if disk_dtype != meta.fst_dtype:
        torch_tensor = torch_tensor.view(meta.dtype)
    if owner is not None:
        torch_tensor._comfy_disk_buffer_owner = owner
    return torch_tensor


def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
    if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
        raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")


def _get_gds_o_direct(framework) -> bool:
    cuda_ver = framework.get_cuda_ver()
    if cuda_ver and cuda_ver != "0.0":
        ver_parts = cuda_ver.split("-", 1)
        if len(ver_parts) == 2:
            cudavers = list(map(int, ver_parts[1].split(".")))
            if ver_parts[0] == "cuda":
                return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
        return True
    return True


def _ensure_gds_ready(device: torch.device):
    fst = _init_fastsafetensors_lib()
    if not fst.common.is_gpu_found():
        raise RuntimeError(
            "GPUDirect requested but GPU runtime library is missing. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
    if gds_supported < 0:
        raise RuntimeError(
            "GPUDirect requested but is_gds_supported() failed. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    if not fst.cpp.is_cufile_found():
        raise RuntimeError(
            "GPUDirect requested but libcufile is missing. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    if gds_supported == 0:
        raise RuntimeError(
            "GPUDirect requested but GDS is unsupported on this platform. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    _init_gds()
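

# StreamStateDict reads tensors from disk on every access (nothing is cached)
# and layers in-memory overrides and deletions on top of the file's contents.
# A minimal usage sketch; the file path and key are hypothetical:
#
#   sd = StreamStateDict.from_file("/models/unet.safetensors", torch.device("cuda:0"))
#   w = sd.get_tensor("input_blocks.0.0.weight", dtype=torch.float16)
#   sd.close()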
class StreamStateDict(collections.abc.MutableMapping):
    is_stream_state_dict = True

    def __init__(
        self,
        index: SafeTensorIndex,
        file: _SafeTensorFile,
        device: torch.device,
        allow_gds: bool = False,
    ):
        self._index = index
        self._file = file
        self._device = device
        self._allow_gds = allow_gds
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    @classmethod
    def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
        index = SafeTensorIndex(filename)
        file = _SafeTensorFile(filename, index)
        return cls(index, file, device, allow_gds=allow_gds)

    def close(self):
        if self._file is not None:
            self._file.release()
            self._file = None

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass

    def meta(self, key: str) -> TensorMeta:
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return TensorMeta(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
                data_offsets=(0, numel * t.element_size()),
                filename="<override>",
                fst_dtype=None,
                strides=tuple(t.stride()),
            )
        if key in self._deleted:
            raise KeyError(key)
        return self._index.meta(key)

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        if key in self._overrides:
            t = self._overrides[key]
            if device is not None and t.device != device:
                t = t.to(device=device)
            if dtype is not None and t.dtype != dtype:
                _validate_dtype_conversion(t.dtype, dtype)
                t = t.to(dtype=dtype)
            return t
        if key in self._deleted:
            raise KeyError(key)
        if device is None:
            device = self._device
        if device.type == "meta":
            meta = self._index.meta(key)
            target_dtype = dtype or meta.dtype
            if dtype is not None and dtype != meta.dtype:
                _validate_dtype_conversion(meta.dtype, dtype)
            return torch.empty(meta.shape, dtype=target_dtype, device="meta")
        if allow_gds is None:
            allow_gds = self._allow_gds
        meta = self._index.meta(key)
        return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.get_tensor(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        if key in self._deleted:
            raise KeyError(key)
        if self._index.has(key):
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        for k in self._index.keys():
            if k in self._deleted:
                continue
            if k in self._overrides:
                continue
            yield k
        for k in self._overrides.keys():
            yield k

    def __len__(self) -> int:
        # Count exactly the keys __iter__ yields: overrides that shadow file
        # keys must not be counted twice.
        count = len(self._overrides)
        for k in self._index.keys():
            if k not in self._deleted and k not in self._overrides:
                count += 1
        return count

    def __contains__(self, key: object) -> bool:
        if not isinstance(key, str):
            return False
        if key in self._deleted:
            return False
        if key in self._overrides:
            return True
        return self._index.has(key)

    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        if key in self._deleted:
            if default is _MISSING:
                raise KeyError(key)
            return default
        if self._index.has(key):
            # Read before marking deleted; get_tensor raises for deleted keys.
            tensor = self.get_tensor(key)
            self._deleted.add(key)
            return tensor
        if default is _MISSING:
            raise KeyError(key)
        return default

    def copy(self) -> "StreamStateDict":
        new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
        new._overrides = dict(self._overrides)
        new._deleted = set(self._deleted)
        return new

    def metadata(self):
        return self._index.metadata()
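

# The view classes below wrap another (possibly streaming) state dict without
# copying tensors. Subclasses override _resolve_base_key/_iter_base_keys to
# rename or hide keys; writes and deletes stay local to the view unless
# mutate_base=True.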
class _BaseViewStateDict(MutableMapping):
    is_stream_state_dict = True

    def __init__(self, base: MutableMapping, mutate_base: bool = False):
        self._base = base
        self._mutate_base = mutate_base
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return key

    def _iter_base_keys(self) -> Iterable[str]:
        return self._base.keys()

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        if key in self._overrides:
            t = self._overrides[key]
            if device is not None and t.device != device:
                t = t.to(device=device)
            if dtype is not None and t.dtype != dtype:
                _validate_dtype_conversion(t.dtype, dtype)
                t = t.to(dtype=dtype)
            return t
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if hasattr(self._base, "get_tensor"):
            return self._base.get_tensor(
                base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
            )
        t = self._base[base_key]
        if device is not None and t.device != device:
            t = t.to(device=device)
        if dtype is not None and t.dtype != dtype:
            _validate_dtype_conversion(t.dtype, dtype)
            t = t.to(dtype=dtype)
        return t

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.get_tensor(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        base_key = self._resolve_base_key(key)
        if self._mutate_base and base_key is not None and base_key in self._base:
            self._base[base_key] = value
        else:
            self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if self._mutate_base and base_key in self._base:
            del self._base[base_key]
        else:
            self._deleted.add(key)

    def __iter__(self) -> Iterator[str]:
        for k in self._iter_base_keys():
            # Shadowed base keys are yielded once, from the overrides pass.
            if k in self._deleted or k in self._overrides:
                continue
            yield k
        for k in self._overrides.keys():
            yield k

    def __len__(self) -> int:
        count = len(self._overrides)
        for k in self._iter_base_keys():
            if k not in self._deleted and k not in self._overrides:
                count += 1
        return count

    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            if default is _MISSING:
                raise KeyError(key)
            return default
        if self._mutate_base:
            try:
                return self._base.pop(base_key)
            except KeyError:
                if default is _MISSING:
                    raise
                return default
        # Read before marking deleted; get_tensor raises for deleted keys.
        tensor = self.get_tensor(key)
        self._deleted.add(key)
        return tensor

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return SimpleNamespace(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
            )
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if hasattr(self._base, "meta"):
            return self._base.meta(base_key)
        t = self._base[base_key]
        numel = t.numel()
        return SimpleNamespace(
            dtype=t.dtype,
            shape=tuple(t.shape),
            numel=numel,
            nbytes=numel * t.element_size(),
        )
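

# DeviceViewStateDict only changes the defaults get_tensor applies (target
# device, GDS policy, pinning); all mapping behavior is inherited unchanged
# from _BaseViewStateDict.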
class DeviceViewStateDict(_BaseViewStateDict):
    def __init__(
        self,
        base: MutableMapping,
        device: torch.device,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
        mutate_base: bool = False,
    ):
        super().__init__(base, mutate_base=mutate_base)
        self._device = device
        self._allow_gds = allow_gds
        self._pin_if_cpu = pin_if_cpu

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        device = self._device if device is None else device
        allow_gds = self._allow_gds if allow_gds is None else allow_gds
        pin_if_cpu = pin_if_cpu or self._pin_if_cpu
        return super().get_tensor(
            key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
        )
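

# Key-remapping views: FilterViewStateDict hides keys that fail a predicate,
# PrefixViewStateDict re-exposes source_prefix* keys under target_prefix*, and
# RenameViewStateDict rewrites prefixes from a replacement table. Note that
# the prefix/rename views snapshot the base key set at construction time.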
class FilterViewStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._predicate = predicate

    def _resolve_base_key(self, key: str) -> Optional[str]:
        if self._predicate(key):
            return key
        return None

    def _iter_base_keys(self) -> Iterable[str]:
        for k in self._base.keys():
            if self._predicate(k):
                yield k


class PrefixViewStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._source_prefix = source_prefix
        self._target_prefix = target_prefix
        self._mapping: Dict[str, str] = {}
        self._reverse: Dict[str, str] = {}
        for k in base.keys():
            if not k.startswith(source_prefix):
                continue
            view_key = f"{target_prefix}{k[len(source_prefix):]}"
            self._mapping[k] = view_key
            self._reverse[view_key] = k

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._reverse.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._reverse.keys()


class RenameViewStateDict(_BaseViewStateDict):
    def __init__(
        self,
        base: MutableMapping,
        replace_prefix: Mapping[str, str],
        filter_keys: bool = False,
        mutate_base: bool = False,
    ):
        super().__init__(base, mutate_base=mutate_base)
        self._filter_keys = filter_keys
        self._replace = list(replace_prefix.items())
        self._mapping: Dict[str, str] = {}
        self._reverse: Dict[str, str] = {}
        for k in base.keys():
            view_key = self._replace_key(k)
            if view_key is None:
                continue
            self._mapping[k] = view_key
            self._reverse[view_key] = k

    def _replace_key(self, key: str) -> Optional[str]:
        for rp, dst in self._replace:
            if key.startswith(rp):
                return f"{dst}{key[len(rp):]}"
        if self._filter_keys:
            return None
        return key

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._reverse.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._reverse.keys()
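

# MergedStateDict overlays several state dicts; on lookup the last mapping
# containing a key wins (hence the reversed() scan), matching the precedence
# of a {**a, **b}-style merge.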
class MergedStateDict(MutableMapping):
    is_stream_state_dict = True

    def __init__(self, *mappings: MutableMapping):
        self._mappings = list(mappings)
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    def __getitem__(self, key: str) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides[key]
        if key in self._deleted:
            raise KeyError(key)
        for mapping in reversed(self._mappings):
            if key in mapping:
                if hasattr(mapping, "get_tensor"):
                    return mapping.get_tensor(key)
                return mapping[key]
        raise KeyError(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        if key in self._deleted:
            raise KeyError(key)
        if any(key in mapping for mapping in self._mappings):
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        seen = set()
        for mapping in self._mappings:
            for key in mapping.keys():
                if key in self._deleted or key in seen:
                    continue
                seen.add(key)
                yield key
        for key in self._overrides.keys():
            if key not in seen:
                yield key

    def __len__(self) -> int:
        return len(list(self.__iter__()))

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return SimpleNamespace(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
            )
        if key in self._deleted:
            raise KeyError(key)
        for mapping in reversed(self._mappings):
            if key in mapping:
                if hasattr(mapping, "meta"):
                    return mapping.meta(key)
                t = mapping[key]
                numel = t.numel()
                return SimpleNamespace(
                    dtype=t.dtype,
                    shape=tuple(t.shape),
                    numel=numel,
                    nbytes=numel * t.element_size(),
                )
        raise KeyError(key)
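

# MappedStateDict exposes base keys under explicitly supplied names: key_map
# maps base keys to view keys, and only the mapped keys are visible.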
class MappedStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._base_to_view = dict(key_map)
        self._view_to_base = {v: k for k, v in key_map.items()}

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._view_to_base.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._view_to_base.keys()