mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
Chunk nogds tensor reads and fix view mappings
This commit is contained in:
parent
45a77073ac
commit
010d5445fe
@ -19,6 +19,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import ctypes
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
@ -35,6 +36,7 @@ _FST_LOCK = threading.Lock()
|
||||
_FST_LOADED = False
|
||||
_GDS_INITIALIZED = False
|
||||
_MISSING = object()
|
||||
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
|
||||
|
||||
|
||||
def _require_fastsafetensors():
|
||||
@ -236,24 +238,46 @@ class _SafeTensorFile:
|
||||
reader = self._ensure_nogds_reader(use_cuda=False)
|
||||
abs_start = self.index.header_length + meta.data_offsets[0]
|
||||
length = meta.data_offsets[1] - meta.data_offsets[0]
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
|
||||
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
|
||||
chunk_bytes = max(1, chunk_bytes)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
buffer_length = aligned_length + ptr_align
|
||||
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
|
||||
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
|
||||
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
|
||||
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
|
||||
if reader.wait_read(req) < 0:
|
||||
disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
|
||||
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=disk_dtype, device="cpu")
|
||||
buffer_length = 0
|
||||
buf_ptr = None
|
||||
gbuf = None
|
||||
try:
|
||||
chunk_offset = 0
|
||||
while chunk_offset < length:
|
||||
chunk_len = min(length - chunk_offset, chunk_bytes)
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
|
||||
needed = aligned_length + ptr_align
|
||||
if buf_ptr is None or needed > buffer_length:
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
buffer_length = needed
|
||||
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
|
||||
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
|
||||
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
|
||||
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
|
||||
if reader.wait_read(req) < 0:
|
||||
raise RuntimeError("nogds_file_reader read failed")
|
||||
src_ptr = gbuf.get_base_address() + ptr_off + head
|
||||
dest_ptr = dest_tensor.data_ptr() + chunk_offset
|
||||
ctypes.memmove(dest_ptr, src_ptr, chunk_len)
|
||||
chunk_offset += chunk_len
|
||||
except Exception:
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
raise
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
raise RuntimeError("nogds_file_reader read failed")
|
||||
owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr))
|
||||
tensor = _dlpack_tensor_from_buffer(
|
||||
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
|
||||
)
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
_validate_dtype_conversion(tensor.dtype, dtype)
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
return tensor
|
||||
if disk_dtype != meta.dtype:
|
||||
dest_tensor = dest_tensor.view(meta.dtype)
|
||||
if dtype is not None and dtype != dest_tensor.dtype:
|
||||
_validate_dtype_conversion(dest_tensor.dtype, dtype)
|
||||
dest_tensor = dest_tensor.to(dtype=dtype)
|
||||
return dest_tensor
|
||||
|
||||
def _read_tensor_gds(
|
||||
self,
|
||||
@ -565,6 +589,59 @@ class _BaseViewStateDict(MutableMapping):
|
||||
t = t.to(dtype=dtype)
|
||||
return t
|
||||
|
||||
def __getitem__(self, key: str) -> torch.Tensor:
|
||||
return self.get_tensor(key)
|
||||
|
||||
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
||||
base_key = self._resolve_base_key(key)
|
||||
if self._mutate_base and base_key is not None and base_key in self._base:
|
||||
self._base[base_key] = value
|
||||
else:
|
||||
self._overrides[key] = value
|
||||
self._deleted.discard(key)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
if key in self._overrides:
|
||||
del self._overrides[key]
|
||||
return
|
||||
base_key = self._resolve_base_key(key)
|
||||
if base_key is None or key in self._deleted:
|
||||
raise KeyError(key)
|
||||
if self._mutate_base and base_key in self._base:
|
||||
del self._base[base_key]
|
||||
else:
|
||||
self._deleted.add(key)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for k in self._iter_base_keys():
|
||||
if k in self._deleted:
|
||||
continue
|
||||
yield k
|
||||
for k in self._overrides.keys():
|
||||
yield k
|
||||
|
||||
def __len__(self) -> int:
|
||||
base_keys = list(self._iter_base_keys())
|
||||
return len(base_keys) - len(self._deleted) + len(self._overrides)
|
||||
|
||||
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
    """Remove *key* from the view and return its tensor.

    Overrides are popped directly.  A key that only exists in the base
    mapping is either popped from the base (when base mutation is
    allowed) or tombstoned in ``_deleted`` after its value has been
    materialized.

    Raises:
        KeyError: if *key* is absent or already deleted and no
            *default* was supplied.
    """
    # Overrides shadow everything else, so they win the lookup.
    if key in self._overrides:
        return self._overrides.pop(key)
    base_key = self._resolve_base_key(key)
    if base_key is None or key in self._deleted:
        if default is _MISSING:
            raise KeyError(key)
        return default
    if self._mutate_base:
        try:
            return self._base.pop(base_key)
        except KeyError:
            if default is _MISSING:
                raise
            return default
    # Fix: materialize the value BEFORE tombstoning the key.  The old
    # order added the key to _deleted first, so any KeyError raised by
    # get_tensor() left the mapping already mutated (and, if get_tensor
    # honors _deleted, made every non-mutating pop fail outright).
    value = self.get_tensor(key)
    self._deleted.add(key)
    return value
|
||||
|
||||
|
||||
class DeviceViewStateDict(_BaseViewStateDict):
|
||||
def __init__(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user