mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
utils: bypass safetensors mmap when disabled
Load safetensors through a direct read path under --disable-mmap so unified-memory systems avoid retaining mmap-backed file pages alongside framework tensors. Made-with: Cursor
This commit is contained in:
parent
fce0398470
commit
1ff364873a
@ -119,6 +119,76 @@ def load_safetensors(ckpt):
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
|
||||
def load_safetensors_no_mmap(ckpt, device=None, return_metadata=False):
    """Load a ``.safetensors`` / ``.sft`` checkpoint without mmap'ing the file.

    safetensors' ``safe_open()`` (and thus ``safetensors.torch.load_file``)
    always memory-maps the file on the Rust side. On unified CPU/GPU memory
    hardware (NVIDIA Grace Blackwell / DGX Spark, Apple Silicon, AMD APUs, ...)
    the mmap-backed page-cache pages and the framework's own tensor copy live
    in the same physical pool, roughly doubling peak memory for large models.
    See: https://github.com/Comfy-Org/ComfyUI/issues/10896
         https://github.com/safetensors/safetensors/issues/758
         https://github.com/safetensors/safetensors/pull/759

    Temporary workaround until upstream exposes a public ``mmap=False``:
    parse the safetensors header by hand, ``readinto`` each tensor's bytes
    into its own ``bytearray``, and wrap that buffer zero-copy with
    ``torch.frombuffer``. Peak usage stays at one model copy (plus one
    in-flight tensor's bytes while moving to a non-CPU device).

    Args:
        ckpt: path to the safetensors file.
        device: target ``torch.device``; defaults to CPU.
        return_metadata: when True, also collect the ``__metadata__`` dict.

    Returns:
        ``(state_dict, metadata)`` — ``metadata`` is ``None`` unless
        ``return_metadata`` is True.

    Raises:
        ValueError: on a truncated or malformed file.
    """
    device = torch.device("cpu") if device is None else device

    sd = {}
    metadata = None
    with open(ckpt, "rb") as fh:
        # First 8 bytes: little-endian u64 giving the JSON header length.
        prefix = fh.read(8)
        if len(prefix) != 8:
            raise ValueError("HeaderTooLarge: file is too small to be a valid safetensors file: {}".format(ckpt))
        (header_size,) = struct.unpack("<Q", prefix)

        raw_header = fh.read(header_size)
        if len(raw_header) != header_size:
            raise ValueError("MetadataIncompleteBuffer: truncated header in {}".format(ckpt))
        index = json.loads(raw_header.decode("utf-8"))
        data_base_offset = 8 + header_size  # tensor offsets are relative to this

        if return_metadata:
            metadata = index.get("__metadata__", {})

        for key, entry in index.items():
            if key == "__metadata__":
                continue

            dtype = _TYPES[entry["dtype"]]
            shape = entry["shape"]
            start, end = entry["data_offsets"]
            size = end - start

            if size == 0:
                # Zero-byte tensors (some dim is 0) have no payload to read.
                tensor = torch.empty(shape, dtype=dtype)
            else:
                raw = bytearray(size)
                fh.seek(data_base_offset + start)
                mv = memoryview(raw)
                filled = 0
                # readinto may return short reads; loop until the buffer is full.
                while filled < size:
                    got = fh.readinto(mv[filled:])
                    if not got:
                        raise ValueError("MetadataIncompleteBuffer: unexpected EOF reading tensor {!r} from {}".format(key, ckpt))
                    filled += got
                # Zero-copy wrap: the tensor aliases `raw`, which it keeps alive.
                tensor = torch.frombuffer(raw, dtype=dtype).reshape(shape)

            if device.type != "cpu":
                # Device copy; the CPU-side bytearray is freed once `tensor`
                # is rebound, keeping only one tensor's bytes in flight.
                tensor = tensor.to(device=device)
            sd[key] = tensor

    return sd, metadata
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
@ -129,14 +199,15 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
sd, metadata = load_safetensors(ckpt)
|
||||
if not return_metadata:
|
||||
metadata = None
|
||||
elif DISABLE_MMAP:
|
||||
sd, metadata = load_safetensors_no_mmap(ckpt, device=device, return_metadata=return_metadata)
|
||||
if not return_metadata:
|
||||
metadata = None
|
||||
else:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
sd[k] = f.get_tensor(k)
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user