Implement seek and read for pins

Source pins from an mmap is pad because its its a CPU->CPU copy that
attempts to fully buffer the same data twice. Instead, use seek and
read which avoids the mmap buffering while usually being a faster
read in the first place (avoiding mmap faulting etc).
This commit is contained in:
Rattus 2026-03-09 22:53:23 +10:00
parent 4a8cf359fe
commit 6fbbcc4cb7
3 changed files with 52 additions and 3 deletions

View File

@ -1,9 +1,51 @@
import math
import ctypes
import threading
import torch
from typing import NamedTuple
import logging
from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int
def read_tensor_file_slice_into(tensor, destination):
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
return False
if info.size == 0:
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
n = file_obj.readinto(view[done:])
if n <= 0:
return False
done += n
return True
finally:
view.release()
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype

View File

@ -1225,6 +1225,8 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0)
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
dest_view.copy_(tensor, non_blocking=non_blocking)

View File

@ -33,6 +33,7 @@ from comfy.cli_args import args
import json
import time
import mmap
import threading
import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
@ -81,14 +82,14 @@ _TYPES = {
}
def load_safetensors(ckpt):
f = open(ckpt, "rb")
f = open(ckpt, "rb", buffering=0)
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
mv = mv[8 + header_size:]
mv = mv[(data_base_offset := 8 + header_size):]
sd = {}
for name, info in header.items():
@ -102,7 +103,11 @@ def load_safetensors(ckpt):
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
setattr(tensor.untyped_storage(),
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
sd[name] = tensor
return sd, header.get("__metadata__", {}),