This commit is contained in:
Johnny 2026-05-09 04:28:04 +08:00 committed by GitHub
commit c57ee6a4dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -119,6 +119,76 @@ def load_safetensors(ckpt):
return sd, header.get("__metadata__", {}),
def load_safetensors_no_mmap(ckpt, device=None, return_metadata=False):
# Load a .safetensors / .sft file without ever mmap'ing it.
#
# safetensors.safe_open() (and therefore safetensors.torch.load_file) always
# mmaps the underlying file in Rust. On systems with unified CPU/GPU memory
# like NVIDIA Grace Blackwell / DGX Spark, Apple Silicon, AMD APUs, etc.
# this is fatal for large models: the OS page-cache pages backing the mmap
# and any subsequent device copy both reside in the same physical memory
# pool, doubling peak memory and causing OOM well before the hardware
# limit is reached.
# See: https://github.com/Comfy-Org/ComfyUI/issues/10896
# https://github.com/safetensors/safetensors/issues/758
# https://github.com/safetensors/safetensors/pull/759
#
# This is a temporary workaround until upstream safetensors exposes a
# public ``mmap=False`` option. Here we parse the safetensors header
# ourselves and read each tensor straight from disk into a per-tensor
# ``bytearray`` via ``readinto``, then zero-copy-wrap it as a torch tensor
# with ``torch.frombuffer``. Peak memory is one model copy (plus, if a
# non-CPU device is requested, the bytes of a single tensor in flight
# while it is being moved).
if device is None:
device = torch.device("cpu")
sd = {}
metadata = None
with open(ckpt, "rb") as f:
header_bytes = f.read(8)
if len(header_bytes) != 8:
raise ValueError("HeaderTooLarge: file is too small to be a valid safetensors file: {}".format(ckpt))
header_size = struct.unpack("<Q", header_bytes)[0]
header_data = f.read(header_size)
if len(header_data) != header_size:
raise ValueError("MetadataIncompleteBuffer: truncated header in {}".format(ckpt))
header = json.loads(header_data.decode("utf-8"))
data_base_offset = 8 + header_size
if return_metadata:
metadata = header.get("__metadata__", {})
for name, info in header.items():
if name == "__metadata__":
continue
dtype = _TYPES[info["dtype"]]
shape = info["shape"]
start, end = info["data_offsets"]
num_bytes = end - start
if num_bytes == 0:
tensor = torch.empty(shape, dtype=dtype)
else:
buf = bytearray(num_bytes)
f.seek(data_base_offset + start)
view = memoryview(buf)
offset = 0
while offset < num_bytes:
n = f.readinto(view[offset:])
if not n:
raise ValueError("MetadataIncompleteBuffer: unexpected EOF reading tensor {!r} from {}".format(name, ckpt))
offset += n
tensor = torch.frombuffer(buf, dtype=dtype).reshape(shape)
if device.type != "cpu":
tensor = tensor.to(device=device)
sd[name] = tensor
return sd, metadata
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
@ -129,14 +199,15 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd, metadata = load_safetensors(ckpt)
if not return_metadata:
metadata = None
elif DISABLE_MMAP:
sd, metadata = load_safetensors_no_mmap(ckpt, device=device, return_metadata=return_metadata)
if not return_metadata:
metadata = None
else:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e: