""" This file is part of ComfyUI. Copyright (C) 2024 Comfy This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ from __future__ import annotations import collections import ctypes import importlib import importlib.util import os import threading from dataclasses import dataclass from types import SimpleNamespace from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple import torch _FST_MODULE = None _FST_LOCK = threading.Lock() _FST_LOADED = False _GDS_INITIALIZED = False _MISSING = object() _NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024 def _require_fastsafetensors(): global _FST_MODULE with _FST_LOCK: if _FST_MODULE is None: if importlib.util.find_spec("fastsafetensors") is None: raise ImportError( "fastsafetensors is required for safetensors streaming. " "Install it with: pip install 'fastsafetensors @ https://github.com/" "foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'" ) _FST_MODULE = importlib.import_module("fastsafetensors") return _FST_MODULE def _init_fastsafetensors_lib(): global _FST_LOADED fst = _require_fastsafetensors() if not _FST_LOADED: fst.cpp.load_library_functions() _FST_LOADED = True return fst def _init_gds(): global _GDS_INITIALIZED fst = _init_fastsafetensors_lib() if not _GDS_INITIALIZED: if fst.cpp.init_gds() != 0: raise RuntimeError("fastsafetensors init_gds() failed") _GDS_INITIALIZED = True @dataclass(frozen=True) class TensorMeta: dtype: torch.dtype shape: Tuple[int, ...] numel: int nbytes: int data_offsets: Tuple[int, int] filename: str fst_dtype: object strides: Tuple[int, ...] 

class SafeTensorIndex:
    def __init__(self, filename: str):
        fst = _init_fastsafetensors_lib()
        framework = fst.frameworks.get_framework_op("pytorch")
        metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
        self._filename = filename
        self._metadata = metadata
        self._framework = framework
        from fastsafetensors.frameworks import _torch as fst_torch
        self._dtype_map = fst_torch.dtype_convert
        self._tensor_meta: Dict[str, TensorMeta] = {}
        for key, frame in metadata.tensors.items():
            torch_dtype = self._dtype_map.get(frame.dtype, None)
            if torch_dtype is None:
                raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
            numel = 1
            for s in frame.shape:
                numel *= s
            nbytes = numel * framework.get_dtype_size(frame.dtype)
            self._tensor_meta[key] = TensorMeta(
                dtype=torch_dtype,
                shape=tuple(frame.shape),
                numel=numel,
                nbytes=nbytes,
                data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
                filename=filename,
                fst_dtype=frame.dtype,
                strides=tuple(frame.strides),
            )

    def keys(self) -> Iterable[str]:
        return self._tensor_meta.keys()

    def has(self, key: str) -> bool:
        return key in self._tensor_meta

    def meta(self, key: str) -> TensorMeta:
        return self._tensor_meta[key]

    def metadata(self):
        return self._metadata.metadata

    @property
    def header_length(self) -> int:
        return self._metadata.header_length

    @property
    def size_bytes(self) -> int:
        return self._metadata.size_bytes


class _SafeTensorFile:
    def __init__(self, filename: str, index: SafeTensorIndex):
        self.filename = filename
        self.index = index
        self._fd: Optional[int] = None
        self._gds_handle = None
        self._gds_reader = None
        self._nogds_reader = None
        self._refcount = 1

    def acquire(self) -> "_SafeTensorFile":
        self._refcount += 1
        return self

    def release(self):
        self._refcount -= 1
        if self._refcount <= 0:
            self.close()

    def close(self):
        if self._fd is not None:
            os.close(self._fd)
            self._fd = None
        self._gds_handle = None

    def _ensure_fd(self) -> int:
        if self._fd is None:
            self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
        return self._fd

    def _ensure_nogds_reader(self, use_cuda: bool):
        fst = _init_fastsafetensors_lib()
        if self._nogds_reader is None:
            self._nogds_reader = fst.cpp.nogds_file_reader(False, 16 * 1024, 16, use_cuda)
        return self._nogds_reader

    def _ensure_gds_reader(self, use_cuda: bool):
        fst = _init_fastsafetensors_lib()
        if self._gds_reader is None:
            self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
        return self._gds_reader

    def _ensure_gds_handle(self, use_cuda: bool):
        if self._gds_handle is None:
            fst = _init_fastsafetensors_lib()
            framework = fst.frameworks.get_framework_op("pytorch")
            o_direct = _get_gds_o_direct(framework)
            self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
        return self._gds_handle

    def read_tensor(
        self,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
        allow_gds: bool,
        pin_if_cpu: bool,
    ) -> torch.Tensor:
        fst = _init_fastsafetensors_lib()
        framework = fst.frameworks.get_framework_op("pytorch")
        device_is_cuda = device.type == "cuda"
        if device_is_cuda and allow_gds:
            _ensure_gds_ready(device)
            return self._read_tensor_gds(fst, framework, meta, device, dtype)
        cpu_tensor = self._read_tensor_nogds(fst, framework, meta, torch.device("cpu"), dtype)
        if device_is_cuda:
            if pin_if_cpu:
                cpu_tensor = cpu_tensor.pin_memory()
            gpu_tensor = torch.empty_like(cpu_tensor, device=device)
            gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
            return gpu_tensor
        return cpu_tensor

    def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
        fst = _init_fastsafetensors_lib()
        align = fst.cpp.get_alignment_size()
        aligned_offset = (abs_start // align) * align
        head = abs_start - aligned_offset
        aligned_length = length + head
        tail = aligned_length % align
        if tail:
            aligned_length += align - tail
        return aligned_offset, aligned_length, head
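
    # Worked example for _aligned_range (illustrative numbers): with a
    # 4096-byte alignment, a request for 5000 bytes at absolute offset 10000
    # becomes a read of aligned_length=8192 bytes at aligned_offset=8192, with
    # the payload starting head=1808 bytes into the buffer:
    #   aligned_offset = (10000 // 4096) * 4096 = 8192
    #   head           = 10000 - 8192 = 1808
    #   aligned_length = 5000 + 1808 = 6808, padded up to the next multiple
    #                    of 4096 = 8192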

    def _read_tensor_nogds(
        self,
        fst,
        framework,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
    ) -> torch.Tensor:
        fd = self._ensure_fd()
        reader = self._ensure_nogds_reader(use_cuda=False)
        abs_start = self.index.header_length + meta.data_offsets[0]
        length = meta.data_offsets[1] - meta.data_offsets[0]
        chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
        chunk_bytes = max(1, chunk_bytes)
        ptr_align = framework.get_device_ptr_align()
        dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
        buffer_length = 0
        buf_ptr = None
        gbuf = None
        try:
            chunk_offset = 0
            while chunk_offset < length:
                chunk_len = min(length - chunk_offset, chunk_bytes)
                aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
                needed = aligned_length + ptr_align
                if buf_ptr is None or needed > buffer_length:
                    if buf_ptr is not None:
                        fst.cpp.cpu_free(buf_ptr)
                    buffer_length = needed
                    buf_ptr = fst.cpp.cpu_malloc(buffer_length)
                    gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
                ptr_off = (-(gbuf.get_base_address() + head)) % ptr_align
                req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
                if reader.wait_read(req) < 0:
                    raise RuntimeError("nogds_file_reader read failed")
                src_ptr = gbuf.get_base_address() + ptr_off + head
                dest_ptr = dest_tensor.data_ptr() + chunk_offset
                ctypes.memmove(dest_ptr, src_ptr, chunk_len)
                chunk_offset += chunk_len
        finally:
            if buf_ptr is not None:
                fst.cpp.cpu_free(buf_ptr)
        if dtype is not None and dtype != dest_tensor.dtype:
            _validate_dtype_conversion(dest_tensor.dtype, dtype)
            dest_tensor = dest_tensor.to(dtype=dtype)
        return dest_tensor

    def _read_tensor_gds(
        self,
        fst,
        framework,
        meta: TensorMeta,
        device: torch.device,
        dtype: Optional[torch.dtype],
    ) -> torch.Tensor:
        reader = self._ensure_gds_reader(use_cuda=True)
        handle = self._ensure_gds_handle(use_cuda=True)
        abs_start = self.index.header_length + meta.data_offsets[0]
        length = meta.data_offsets[1] - meta.data_offsets[0]
        aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
        ptr_align = framework.get_device_ptr_align()
        buffer_length = aligned_length + ptr_align
        fst_device = _fst_device_from_torch(fst, device)
        gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
        ptr_off = (-(gbuf.get_base_address() + head)) % ptr_align
        file_length = self.index.size_bytes
        req = reader.submit_read(handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length)
        if reader.wait_read(req) < 0:
            framework.free_tensor_memory(gbuf, fst_device)
            raise RuntimeError("gds_file_reader read failed")
        owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
        tensor = _dlpack_tensor_from_buffer(
            fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
        )
        if dtype is not None and dtype != tensor.dtype:
            _validate_dtype_conversion(tensor.dtype, dtype)
            tensor = tensor.to(dtype=dtype)
        return tensor


def _fst_device_from_torch(fst, device: torch.device):
    if device.type == "cuda" and device.index is not None:
        return fst.st_types.Device.from_str(f"cuda:{device.index}")
    return fst.st_types.Device.from_str(device.type)
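
# Illustrative mapping: torch.device("cuda", 1) -> Device "cuda:1";
# torch.device("cpu") -> Device "cpu"; torch.device("cuda") with no index
# falls through to the bare type string "cuda".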

class _BufferOwner:
    def __init__(self, free_fn):
        self._free_fn = free_fn

    def __del__(self):
        try:
            self._free_fn()
        except Exception:
            pass


def _dlpack_tensor_from_buffer(
    fst,
    framework,
    ptr: int,
    meta: TensorMeta,
    device: torch.device,
    owner: Optional[_BufferOwner],
) -> torch.Tensor:
    disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
    dev = _fst_device_from_torch(fst, device)
    dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
    torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
    if disk_dtype != meta.fst_dtype:
        torch_tensor = torch_tensor.view(meta.dtype)
    if owner is not None:
        # Keep the buffer owner alive as long as the tensor that aliases it.
        torch_tensor._comfy_disk_buffer_owner = owner
    return torch_tensor


def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
    if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
        raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")


def _get_gds_o_direct(framework) -> bool:
    cuda_ver = framework.get_cuda_ver()
    if cuda_ver and cuda_ver != "0.0":
        ver_parts = cuda_ver.split("-", 1)
        if len(ver_parts) == 2:
            cudavers = list(map(int, ver_parts[1].split(".")))
            if ver_parts[0] == "cuda":
                # O_DIRECT is only required for CUDA versions before 12.2.
                return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
            return True
    return True


def _ensure_gds_ready(device: torch.device):
    fst = _init_fastsafetensors_lib()
    if not fst.common.is_gpu_found():
        raise RuntimeError(
            "GPUDirect requested but GPU runtime library is missing. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
    if gds_supported < 0:
        raise RuntimeError(
            "GPUDirect requested but is_gds_supported() failed. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    if not fst.cpp.is_cufile_found():
        raise RuntimeError(
            "GPUDirect requested but libcufile is missing. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    if gds_supported == 0:
        raise RuntimeError(
            "GPUDirect requested but GDS is unsupported on this platform. "
            "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
        )
    _init_gds()
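
# Preflight order (sketch): callers that pass allow_gds=True run
#   _ensure_gds_ready(torch.device("cuda", 0))  # raises with guidance if GDS is unusable
# before any gds_file_reader I/O; _init_gds() itself is idempotent behind the
# _GDS_INITIALIZED flag, so repeated reads pay the check only once.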

class StreamStateDict(collections.abc.MutableMapping):
    is_stream_state_dict = True

    def __init__(
        self,
        index: SafeTensorIndex,
        file: _SafeTensorFile,
        device: torch.device,
        allow_gds: bool = False,
    ):
        self._index = index
        self._file = file
        self._device = device
        self._allow_gds = allow_gds
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    @classmethod
    def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
        index = SafeTensorIndex(filename)
        file = _SafeTensorFile(filename, index)
        return cls(index, file, device, allow_gds=allow_gds)

    def close(self):
        if self._file is not None:
            self._file.release()
            self._file = None

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass

    def meta(self, key: str) -> TensorMeta:
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return TensorMeta(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
                data_offsets=(0, numel * t.element_size()),
                filename="",
                fst_dtype=None,
                strides=tuple(t.stride()),
            )
        if key in self._deleted:
            raise KeyError(key)
        return self._index.meta(key)

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        if key in self._overrides:
            t = self._overrides[key]
            if device is not None and t.device != device:
                t = t.to(device=device)
            if dtype is not None and t.dtype != dtype:
                _validate_dtype_conversion(t.dtype, dtype)
                t = t.to(dtype=dtype)
            return t
        if key in self._deleted:
            raise KeyError(key)
        if device is None:
            device = self._device
        if device.type == "meta":
            meta = self._index.meta(key)
            target_dtype = dtype or meta.dtype
            if dtype is not None and dtype != meta.dtype:
                _validate_dtype_conversion(meta.dtype, dtype)
            return torch.empty(meta.shape, dtype=target_dtype, device="meta")
        if allow_gds is None:
            allow_gds = self._allow_gds
        meta = self._index.meta(key)
        return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.get_tensor(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        if key in self._deleted:
            raise KeyError(key)
        if self._index.has(key):
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        for k in self._index.keys():
            if k in self._deleted:
                continue
            if k in self._overrides:
                continue
            yield k
        for k in self._overrides.keys():
            yield k

    def __len__(self) -> int:
        # Iterate so base keys shadowed by overrides are not double-counted.
        return sum(1 for _ in self)

    def __contains__(self, key: object) -> bool:
        if not isinstance(key, str):
            return False
        if key in self._deleted:
            return False
        if key in self._overrides:
            return True
        return self._index.has(key)

    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        if key in self._deleted:
            if default is _MISSING:
                raise KeyError(key)
            return default
        if self._index.has(key):
            # Read before marking deleted: get_tensor() raises for deleted keys.
            tensor = self.get_tensor(key)
            self._deleted.add(key)
            return tensor
        if default is _MISSING:
            raise KeyError(key)
        return default

    def copy(self) -> "StreamStateDict":
        new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
        new._overrides = dict(self._overrides)
        new._deleted = set(self._deleted)
        return new

    def metadata(self):
        return self._index.metadata()
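
# Minimal usage sketch (not executed anywhere in this module; the checkpoint
# path and key are hypothetical). StreamStateDict behaves like a lazy state
# dict: keys come from the safetensors header, and tensor bytes are only read
# from disk when a key is accessed.
#
#     sd = StreamStateDict.from_file("/models/checkpoint.safetensors",
#                                    device=torch.device("cuda", 0),
#                                    allow_gds=False)
#     try:
#         w = sd.get_tensor("some.weight", dtype=torch.float16, pin_if_cpu=True)
#     finally:
#         sd.close()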

class _BaseViewStateDict(MutableMapping):
    is_stream_state_dict = True

    def __init__(self, base: MutableMapping, mutate_base: bool = False):
        self._base = base
        self._mutate_base = mutate_base
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return key

    def _iter_base_keys(self) -> Iterable[str]:
        return self._base.keys()

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        if key in self._overrides:
            t = self._overrides[key]
            if device is not None and t.device != device:
                t = t.to(device=device)
            if dtype is not None and t.dtype != dtype:
                _validate_dtype_conversion(t.dtype, dtype)
                t = t.to(dtype=dtype)
            return t
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if hasattr(self._base, "get_tensor"):
            return self._base.get_tensor(
                base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
            )
        t = self._base[base_key]
        if device is not None and t.device != device:
            t = t.to(device=device)
        if dtype is not None and t.dtype != dtype:
            _validate_dtype_conversion(t.dtype, dtype)
            t = t.to(dtype=dtype)
        return t

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.get_tensor(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        base_key = self._resolve_base_key(key)
        if self._mutate_base and base_key is not None and base_key in self._base:
            self._base[base_key] = value
        else:
            self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if self._mutate_base and base_key in self._base:
            del self._base[base_key]
        else:
            self._deleted.add(key)

    def __iter__(self) -> Iterator[str]:
        for k in self._iter_base_keys():
            # Skip overridden keys here so they are yielded exactly once below.
            if k in self._deleted or k in self._overrides:
                continue
            yield k
        for k in self._overrides.keys():
            yield k

    def __len__(self) -> int:
        # Iterate so base keys shadowed by overrides are not double-counted.
        return sum(1 for _ in self)

    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            if default is _MISSING:
                raise KeyError(key)
            return default
        if self._mutate_base:
            try:
                return self._base.pop(base_key)
            except KeyError:
                if default is _MISSING:
                    raise
                return default
        # Read before marking deleted: get_tensor() raises for deleted keys.
        tensor = self.get_tensor(key)
        self._deleted.add(key)
        return tensor

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return SimpleNamespace(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
            )
        base_key = self._resolve_base_key(key)
        if base_key is None or key in self._deleted:
            raise KeyError(key)
        if hasattr(self._base, "meta"):
            return self._base.meta(base_key)
        t = self._base[base_key]
        numel = t.numel()
        return SimpleNamespace(
            dtype=t.dtype,
            shape=tuple(t.shape),
            numel=numel,
            nbytes=numel * t.element_size(),
        )


class DeviceViewStateDict(_BaseViewStateDict):
    def __init__(
        self,
        base: MutableMapping,
        device: torch.device,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
        mutate_base: bool = False,
    ):
        super().__init__(base, mutate_base=mutate_base)
        self._device = device
        self._allow_gds = allow_gds
        self._pin_if_cpu = pin_if_cpu

    # Mapping behavior (__getitem__, __setitem__, __delitem__, __iter__,
    # __len__, pop, meta) is inherited unchanged from _BaseViewStateDict;
    # only get_tensor differs, filling in this view's defaults.
    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        device = self._device if device is None else device
        allow_gds = self._allow_gds if allow_gds is None else allow_gds
        pin_if_cpu = pin_if_cpu or self._pin_if_cpu
        return super().get_tensor(
            key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
        )


class FilterViewStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._predicate = predicate

    def _resolve_base_key(self, key: str) -> Optional[str]:
        if self._predicate(key):
            return key
        return None

    def _iter_base_keys(self) -> Iterable[str]:
        for k in self._base.keys():
            if self._predicate(k):
                yield k


class PrefixViewStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._source_prefix = source_prefix
        self._target_prefix = target_prefix
        self._mapping: Dict[str, str] = {}
        self._reverse: Dict[str, str] = {}
        for k in base.keys():
            if not k.startswith(source_prefix):
                continue
            view_key = f"{target_prefix}{k[len(source_prefix):]}"
            self._mapping[k] = view_key
            self._reverse[view_key] = k

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._reverse.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._reverse.keys()
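
# Sketch of the prefix view (hypothetical keys): with source_prefix
# "model.diffusion_model." and target_prefix "", a base key
# "model.diffusion_model.out.weight" is exposed as "out.weight"; base keys
# without the prefix are hidden from the view entirely.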

class RenameViewStateDict(_BaseViewStateDict):
    def __init__(
        self,
        base: MutableMapping,
        replace_prefix: Mapping[str, str],
        filter_keys: bool = False,
        mutate_base: bool = False,
    ):
        super().__init__(base, mutate_base=mutate_base)
        self._filter_keys = filter_keys
        self._replace = list(replace_prefix.items())
        self._mapping: Dict[str, str] = {}
        self._reverse: Dict[str, str] = {}
        for k in base.keys():
            view_key = self._replace_key(k)
            if view_key is None:
                continue
            self._mapping[k] = view_key
            self._reverse[view_key] = k

    def _replace_key(self, key: str) -> Optional[str]:
        for rp, dst in self._replace:
            if key.startswith(rp):
                return f"{dst}{key[len(rp):]}"
        if self._filter_keys:
            return None
        return key

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._reverse.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._reverse.keys()


class MergedStateDict(MutableMapping):
    is_stream_state_dict = True

    def __init__(self, *mappings: MutableMapping):
        self._mappings = list(mappings)
        self._overrides: Dict[str, torch.Tensor] = {}
        self._deleted: set[str] = set()

    def __getitem__(self, key: str) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides[key]
        if key in self._deleted:
            raise KeyError(key)
        for mapping in reversed(self._mappings):
            if key in mapping:
                if hasattr(mapping, "get_tensor"):
                    return mapping.get_tensor(key)
                return mapping[key]
        raise KeyError(key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._overrides[key] = value
        self._deleted.discard(key)

    def __delitem__(self, key: str) -> None:
        if key in self._overrides:
            del self._overrides[key]
            return
        if key in self._deleted:
            raise KeyError(key)
        if any(key in mapping for mapping in self._mappings):
            self._deleted.add(key)
            return
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        seen = set()
        for mapping in self._mappings:
            for key in mapping.keys():
                if key in self._deleted or key in seen:
                    continue
                seen.add(key)
                yield key
        for key in self._overrides.keys():
            if key not in seen:
                yield key

    def __len__(self) -> int:
        return sum(1 for _ in self)

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]
            numel = t.numel()
            return SimpleNamespace(
                dtype=t.dtype,
                shape=tuple(t.shape),
                numel=numel,
                nbytes=numel * t.element_size(),
            )
        if key in self._deleted:
            raise KeyError(key)
        for mapping in reversed(self._mappings):
            if key in mapping:
                if hasattr(mapping, "meta"):
                    return mapping.meta(key)
                t = mapping[key]
                numel = t.numel()
                return SimpleNamespace(
                    dtype=t.dtype,
                    shape=tuple(t.shape),
                    numel=numel,
                    nbytes=numel * t.element_size(),
                )
        raise KeyError(key)


class MappedStateDict(_BaseViewStateDict):
    def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
        super().__init__(base, mutate_base=mutate_base)
        self._base_to_view = dict(key_map)
        self._view_to_base = {v: k for k, v in key_map.items()}

    def _resolve_base_key(self, key: str) -> Optional[str]:
        return self._view_to_base.get(key)

    def _iter_base_keys(self) -> Iterable[str]:
        return self._view_to_base.keys()
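
# Views compose (illustrative sketch with hypothetical paths and keys):
# rename a prefix away, then pin loads to a device.
#
#     base = StreamStateDict.from_file("/models/unet.safetensors",
#                                      device=torch.device("cpu"))
#     unet = RenameViewStateDict(base, {"model.diffusion_model.": ""},
#                                filter_keys=True)
#     gpu_view = DeviceViewStateDict(unet, torch.device("cuda", 0),
#                                    pin_if_cpu=True)
#     # gpu_view["out.weight"] streams from disk, stages through pinned CPU
#     # memory, and lands on cuda:0.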