diff --git a/comfy/safetensors_stream.py b/comfy/safetensors_stream.py index f10eac078..f17c12ba1 100644 --- a/comfy/safetensors_stream.py +++ b/comfy/safetensors_stream.py @@ -19,6 +19,7 @@ from __future__ import annotations import collections +import ctypes import importlib import importlib.util import os @@ -35,6 +36,7 @@ _FST_LOCK = threading.Lock() _FST_LOADED = False _GDS_INITIALIZED = False _MISSING = object() +_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024 def _require_fastsafetensors(): @@ -236,24 +238,46 @@ class _SafeTensorFile: reader = self._ensure_nogds_reader(use_cuda=False) abs_start = self.index.header_length + meta.data_offsets[0] length = meta.data_offsets[1] - meta.data_offsets[0] - aligned_offset, aligned_length, head = self._aligned_range(abs_start, length) + chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT)) + chunk_bytes = max(1, chunk_bytes) ptr_align = framework.get_device_ptr_align() - buffer_length = aligned_length + ptr_align - buf_ptr = fst.cpp.cpu_malloc(buffer_length) - gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False) - ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align - req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off) - if reader.wait_read(req) < 0: + disk_dtype = framework.as_workaround_dtype(meta.fst_dtype) + dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=disk_dtype, device="cpu") + buffer_length = 0 + buf_ptr = None + gbuf = None + try: + chunk_offset = 0 + while chunk_offset < length: + chunk_len = min(length - chunk_offset, chunk_bytes) + aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len) + needed = aligned_length + ptr_align + if buf_ptr is None or needed > buffer_length: + if buf_ptr is not None: + fst.cpp.cpu_free(buf_ptr) + buffer_length = needed + buf_ptr = fst.cpp.cpu_malloc(buffer_length) + gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False) + ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align + req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off) + if reader.wait_read(req) < 0: + raise RuntimeError("nogds_file_reader read failed") + src_ptr = gbuf.get_base_address() + ptr_off + head + dest_ptr = dest_tensor.data_ptr() + chunk_offset + ctypes.memmove(dest_ptr, src_ptr, chunk_len) + chunk_offset += chunk_len + except Exception: + if buf_ptr is not None: + fst.cpp.cpu_free(buf_ptr) + raise + if buf_ptr is not None: fst.cpp.cpu_free(buf_ptr) - raise RuntimeError("nogds_file_reader read failed") - owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr)) - tensor = _dlpack_tensor_from_buffer( - fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner - ) - if dtype is not None and dtype != tensor.dtype: - _validate_dtype_conversion(tensor.dtype, dtype) - tensor = tensor.to(dtype=dtype) - return tensor + if disk_dtype != meta.dtype: + dest_tensor = dest_tensor.view(meta.dtype) + if dtype is not None and dtype != dest_tensor.dtype: + _validate_dtype_conversion(dest_tensor.dtype, dtype) + dest_tensor = dest_tensor.to(dtype=dtype) + return dest_tensor def _read_tensor_gds( self, @@ -565,6 +589,59 @@ class _BaseViewStateDict(MutableMapping): t = t.to(dtype=dtype) return t + def __getitem__(self, key: str) -> torch.Tensor: + return self.get_tensor(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + base_key = self._resolve_base_key(key) + if self._mutate_base and base_key is not None and base_key in self._base: + self._base[base_key] = value + else: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + raise KeyError(key) + if self._mutate_base and base_key in self._base: + del self._base[base_key] + else: + self._deleted.add(key) + + def __iter__(self) -> Iterator[str]: + for k in self._iter_base_keys(): + if k in self._deleted: + continue + yield k + for k in self._overrides.keys(): + yield k + + def __len__(self) -> int: + base_keys = list(self._iter_base_keys()) + return len(base_keys) - len(self._deleted) + len(self._overrides) + + def pop(self, key: str, default: object = _MISSING) -> torch.Tensor: + if key in self._overrides: + return self._overrides.pop(key) + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + if default is _MISSING: + raise KeyError(key) + return default + if self._mutate_base: + try: + return self._base.pop(base_key) + except KeyError: + if default is _MISSING: + raise + return default + self._deleted.add(key) + return self.get_tensor(key) + class DeviceViewStateDict(_BaseViewStateDict): def __init__(