Chunk nogds tensor reads and fix view mappings

This commit is contained in:
ifilipis 2026-01-08 19:22:39 +02:00
parent 45a77073ac
commit 010d5445fe


@@ -19,6 +19,7 @@
 from __future__ import annotations
 import collections
+import ctypes
 import importlib
 import importlib.util
 import os
@@ -35,6 +36,7 @@ _FST_LOCK = threading.Lock()
 _FST_LOADED = False
 _GDS_INITIALIZED = False
 _MISSING = object()
+_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
 def _require_fastsafetensors():
@@ -236,24 +238,46 @@ class _SafeTensorFile:
         reader = self._ensure_nogds_reader(use_cuda=False)
         abs_start = self.index.header_length + meta.data_offsets[0]
         length = meta.data_offsets[1] - meta.data_offsets[0]
-        aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
+        chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
+        chunk_bytes = max(1, chunk_bytes)
         ptr_align = framework.get_device_ptr_align()
-        buffer_length = aligned_length + ptr_align
-        buf_ptr = fst.cpp.cpu_malloc(buffer_length)
-        gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
-        ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
-        req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
-        if reader.wait_read(req) < 0:
+        disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
+        dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=disk_dtype, device="cpu")
+        buffer_length = 0
+        buf_ptr = None
+        gbuf = None
+        try:
+            chunk_offset = 0
+            while chunk_offset < length:
+                chunk_len = min(length - chunk_offset, chunk_bytes)
+                aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
+                needed = aligned_length + ptr_align
+                if buf_ptr is None or needed > buffer_length:
+                    if buf_ptr is not None:
+                        fst.cpp.cpu_free(buf_ptr)
+                    buffer_length = needed
+                    buf_ptr = fst.cpp.cpu_malloc(buffer_length)
+                    gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
+                ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
+                req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
+                if reader.wait_read(req) < 0:
+                    raise RuntimeError("nogds_file_reader read failed")
+                src_ptr = gbuf.get_base_address() + ptr_off + head
+                dest_ptr = dest_tensor.data_ptr() + chunk_offset
+                ctypes.memmove(dest_ptr, src_ptr, chunk_len)
+                chunk_offset += chunk_len
+        except Exception:
+            if buf_ptr is not None:
+                fst.cpp.cpu_free(buf_ptr)
+            raise
+        if buf_ptr is not None:
+            fst.cpp.cpu_free(buf_ptr)
-            raise RuntimeError("nogds_file_reader read failed")
-        owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr))
-        tensor = _dlpack_tensor_from_buffer(
-            fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
-        )
-        if dtype is not None and dtype != tensor.dtype:
-            _validate_dtype_conversion(tensor.dtype, dtype)
-            tensor = tensor.to(dtype=dtype)
-        return tensor
+        if disk_dtype != meta.dtype:
+            dest_tensor = dest_tensor.view(meta.dtype)
+        if dtype is not None and dtype != dest_tensor.dtype:
+            _validate_dtype_conversion(dest_tensor.dtype, dtype)
+            dest_tensor = dest_tensor.to(dtype=dtype)
+        return dest_tensor
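A note on the alignment arithmetic the chunk loop relies on: `_aligned_range` is not part of this diff, so the sketch below only assumes it rounds the read window out to a block boundary and reports how many padding bytes (`head`) sit before the real data; the 4096-byte block size and the sample addresses are made-up example values.

```python
# Illustrative sketch only -- _aligned_range is not shown in this commit, and
# the 4096-byte block size here is an assumed example value.
def aligned_range(abs_start: int, length: int, block: int = 4096):
    aligned_offset = (abs_start // block) * block      # round the start down
    head = abs_start - aligned_offset                  # padding before the data
    end = abs_start + length
    aligned_length = -(-(end - aligned_offset) // block) * block  # round the length up
    return aligned_offset, aligned_length, head

# ptr_off shifts where the read lands inside the bounce buffer so that
# base_address + ptr_off + head (the first byte of real data) is a
# multiple of ptr_align:
base, head, ptr_align = 0x1003, 5, 16
ptr_off = (-(base + head)) % ptr_align
assert (base + ptr_off + head) % ptr_align == 0   # 0x1003 + 8 + 5 == 0x1010
```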
     def _read_tensor_gds(
         self,
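The new path reads the raw bytes into a tensor created with a workaround dtype and then relabels them with `.view(meta.dtype)`. This leans on `torch.Tensor.view(dtype)` reinterpreting the same storage under a dtype of equal item size. A minimal illustration of that reinterpretation step; the int16/bfloat16 pairing is just an example, the real mapping comes from `framework.as_workaround_dtype`:

```python
import torch

# Bytes staged on disk as int16 (a stand-in "workaround" dtype), then
# reinterpreted as bfloat16 without copying. Both dtypes are 2 bytes per
# element, so .view(dtype) only relabels the buffer. The bit patterns
# below are bfloat16 encodings of 1.0, 2.0 and 3.0.
raw = torch.tensor([0x3F80, 0x4000, 0x4040], dtype=torch.int16)
reinterpreted = raw.view(torch.bfloat16)
assert reinterpreted.shape == raw.shape
assert reinterpreted.data_ptr() == raw.data_ptr()  # same storage, no copy
```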
@@ -565,6 +589,59 @@ class _BaseViewStateDict(MutableMapping):
             t = t.to(dtype=dtype)
         return t
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self.get_tensor(key)
+    def __setitem__(self, key: str, value: torch.Tensor) -> None:
+        base_key = self._resolve_base_key(key)
+        if self._mutate_base and base_key is not None and base_key in self._base:
+            self._base[base_key] = value
+        else:
+            self._overrides[key] = value
+        self._deleted.discard(key)
+    def __delitem__(self, key: str) -> None:
+        if key in self._overrides:
+            del self._overrides[key]
+            return
+        base_key = self._resolve_base_key(key)
+        if base_key is None or key in self._deleted:
+            raise KeyError(key)
+        if self._mutate_base and base_key in self._base:
+            del self._base[base_key]
+        else:
+            self._deleted.add(key)
+    def __iter__(self) -> Iterator[str]:
+        for k in self._iter_base_keys():
+            if k in self._deleted:
+                continue
+            yield k
+        for k in self._overrides.keys():
+            yield k
+    def __len__(self) -> int:
+        base_keys = list(self._iter_base_keys())
+        return len(base_keys) - len(self._deleted) + len(self._overrides)
+    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+        if key in self._overrides:
+            return self._overrides.pop(key)
+        base_key = self._resolve_base_key(key)
+        if base_key is None or key in self._deleted:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        if self._mutate_base:
+            try:
+                return self._base.pop(base_key)
+            except KeyError:
+                if default is _MISSING:
+                    raise
+                return default
+        self._deleted.add(key)
+        return self.get_tensor(key)
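The mapping methods added here give the view state dict copy-on-write overlay semantics: writes land in `_overrides` unless the base may be mutated, deletions of base keys become tombstones in `_deleted`, and iteration merges the two layers. A simplified, self-contained sketch of that overlay pattern; this is a toy stand-in, not the class in this file, and it omits `_mutate_base` and key remapping:

```python
class OverlayDict:
    """Toy overlay: the base mapping stays untouched, writes go to an
    overrides dict, deletes go to a tombstone set."""
    def __init__(self, base):
        self.base = base
        self.overrides = {}
        self.deleted = set()

    def __setitem__(self, key, value):
        self.overrides[key] = value
        self.deleted.discard(key)      # a write resurrects a deleted key

    def __delitem__(self, key):
        if key in self.overrides:
            del self.overrides[key]
            return
        if key not in self.base or key in self.deleted:
            raise KeyError(key)
        self.deleted.add(key)          # tombstone instead of touching base

    def __getitem__(self, key):
        if key in self.overrides:
            return self.overrides[key]
        if key in self.base and key not in self.deleted:
            return self.base[key]
        raise KeyError(key)

    def keys(self):
        yield from (k for k in self.base if k not in self.deleted)
        yield from self.overrides


base = {"model.weight": 1, "model.bias": 2}
view = OverlayDict(base)
view["extra.scale"] = 3        # new key lives only in the overrides layer
del view["model.weight"]       # tombstoned; the base dict itself is untouched
assert base == {"model.weight": 1, "model.bias": 2}
assert sorted(view.keys()) == ["extra.scale", "model.bias"]
assert view["model.bias"] == 2 and view["extra.scale"] == 3
```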
 class DeviceViewStateDict(_BaseViewStateDict):
     def __init__(
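Usage note: with this commit the nogds CPU path stages each tensor through a bounce buffer of roughly `COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES` bytes (default 64 MiB) plus alignment padding, instead of one buffer sized to the whole tensor. A hedged sketch of lowering that cap on a memory-constrained host; the 16 MiB figure is only an example value:

```python
import os

# Set before the safetensors file is opened so the loader sees the override.
# 16 MiB is an arbitrary example; values below 1 are clamped to 1 by the loader.
os.environ["COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES"] = str(16 * 1024 * 1024)
```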