mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 23:30:16 +08:00
Chunk nogds tensor reads and fix view mappings
This commit is contained in:
parent
45a77073ac
commit
010d5445fe
@ -19,6 +19,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import ctypes
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
@ -35,6 +36,7 @@ _FST_LOCK = threading.Lock()
|
|||||||
_FST_LOADED = False
|
_FST_LOADED = False
|
||||||
_GDS_INITIALIZED = False
|
_GDS_INITIALIZED = False
|
||||||
_MISSING = object()
|
_MISSING = object()
|
||||||
|
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
def _require_fastsafetensors():
|
def _require_fastsafetensors():
|
||||||
@ -236,24 +238,46 @@ class _SafeTensorFile:
|
|||||||
reader = self._ensure_nogds_reader(use_cuda=False)
|
reader = self._ensure_nogds_reader(use_cuda=False)
|
||||||
abs_start = self.index.header_length + meta.data_offsets[0]
|
abs_start = self.index.header_length + meta.data_offsets[0]
|
||||||
length = meta.data_offsets[1] - meta.data_offsets[0]
|
length = meta.data_offsets[1] - meta.data_offsets[0]
|
||||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
|
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
|
||||||
|
chunk_bytes = max(1, chunk_bytes)
|
||||||
ptr_align = framework.get_device_ptr_align()
|
ptr_align = framework.get_device_ptr_align()
|
||||||
buffer_length = aligned_length + ptr_align
|
disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
|
||||||
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
|
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=disk_dtype, device="cpu")
|
||||||
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
|
buffer_length = 0
|
||||||
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
|
buf_ptr = None
|
||||||
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
|
gbuf = None
|
||||||
if reader.wait_read(req) < 0:
|
try:
|
||||||
|
chunk_offset = 0
|
||||||
|
while chunk_offset < length:
|
||||||
|
chunk_len = min(length - chunk_offset, chunk_bytes)
|
||||||
|
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
|
||||||
|
needed = aligned_length + ptr_align
|
||||||
|
if buf_ptr is None or needed > buffer_length:
|
||||||
|
if buf_ptr is not None:
|
||||||
|
fst.cpp.cpu_free(buf_ptr)
|
||||||
|
buffer_length = needed
|
||||||
|
buf_ptr = fst.cpp.cpu_malloc(buffer_length)
|
||||||
|
gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
|
||||||
|
ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
|
||||||
|
req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
|
||||||
|
if reader.wait_read(req) < 0:
|
||||||
|
raise RuntimeError("nogds_file_reader read failed")
|
||||||
|
src_ptr = gbuf.get_base_address() + ptr_off + head
|
||||||
|
dest_ptr = dest_tensor.data_ptr() + chunk_offset
|
||||||
|
ctypes.memmove(dest_ptr, src_ptr, chunk_len)
|
||||||
|
chunk_offset += chunk_len
|
||||||
|
except Exception:
|
||||||
|
if buf_ptr is not None:
|
||||||
|
fst.cpp.cpu_free(buf_ptr)
|
||||||
|
raise
|
||||||
|
if buf_ptr is not None:
|
||||||
fst.cpp.cpu_free(buf_ptr)
|
fst.cpp.cpu_free(buf_ptr)
|
||||||
raise RuntimeError("nogds_file_reader read failed")
|
if disk_dtype != meta.dtype:
|
||||||
owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr))
|
dest_tensor = dest_tensor.view(meta.dtype)
|
||||||
tensor = _dlpack_tensor_from_buffer(
|
if dtype is not None and dtype != dest_tensor.dtype:
|
||||||
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
|
_validate_dtype_conversion(dest_tensor.dtype, dtype)
|
||||||
)
|
dest_tensor = dest_tensor.to(dtype=dtype)
|
||||||
if dtype is not None and dtype != tensor.dtype:
|
return dest_tensor
|
||||||
_validate_dtype_conversion(tensor.dtype, dtype)
|
|
||||||
tensor = tensor.to(dtype=dtype)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
def _read_tensor_gds(
|
def _read_tensor_gds(
|
||||||
self,
|
self,
|
||||||
@ -565,6 +589,59 @@ class _BaseViewStateDict(MutableMapping):
|
|||||||
t = t.to(dtype=dtype)
|
t = t.to(dtype=dtype)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> torch.Tensor:
|
||||||
|
return self.get_tensor(key)
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
||||||
|
base_key = self._resolve_base_key(key)
|
||||||
|
if self._mutate_base and base_key is not None and base_key in self._base:
|
||||||
|
self._base[base_key] = value
|
||||||
|
else:
|
||||||
|
self._overrides[key] = value
|
||||||
|
self._deleted.discard(key)
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
if key in self._overrides:
|
||||||
|
del self._overrides[key]
|
||||||
|
return
|
||||||
|
base_key = self._resolve_base_key(key)
|
||||||
|
if base_key is None or key in self._deleted:
|
||||||
|
raise KeyError(key)
|
||||||
|
if self._mutate_base and base_key in self._base:
|
||||||
|
del self._base[base_key]
|
||||||
|
else:
|
||||||
|
self._deleted.add(key)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[str]:
|
||||||
|
for k in self._iter_base_keys():
|
||||||
|
if k in self._deleted:
|
||||||
|
continue
|
||||||
|
yield k
|
||||||
|
for k in self._overrides.keys():
|
||||||
|
yield k
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
base_keys = list(self._iter_base_keys())
|
||||||
|
return len(base_keys) - len(self._deleted) + len(self._overrides)
|
||||||
|
|
||||||
|
def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
|
||||||
|
if key in self._overrides:
|
||||||
|
return self._overrides.pop(key)
|
||||||
|
base_key = self._resolve_base_key(key)
|
||||||
|
if base_key is None or key in self._deleted:
|
||||||
|
if default is _MISSING:
|
||||||
|
raise KeyError(key)
|
||||||
|
return default
|
||||||
|
if self._mutate_base:
|
||||||
|
try:
|
||||||
|
return self._base.pop(base_key)
|
||||||
|
except KeyError:
|
||||||
|
if default is _MISSING:
|
||||||
|
raise
|
||||||
|
return default
|
||||||
|
self._deleted.add(key)
|
||||||
|
return self.get_tensor(key)
|
||||||
|
|
||||||
|
|
||||||
class DeviceViewStateDict(_BaseViewStateDict):
|
class DeviceViewStateDict(_BaseViewStateDict):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user