implement lightweight safetensors with READ mmap

The CoW MMAP as used by safetensors is hardcoded to CoW which forcibly
consumes windows commit charge on a zero copy. RIP. Implement safetensors
in pytorch itself with a READ mmap to not get commit charged for all our
open models.
This commit is contained in:
Rattus 2026-01-18 22:00:50 +10:00
parent c1a9b4d565
commit 876b886f2a

View File

@ -28,8 +28,12 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
from comfy.cli_args import args, enables_dynamic_vram
import json
import mmap
import ctypes
import packaging
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -55,21 +59,72 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
# Current as of safetensors 0.7.0
_TYPES = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
"C64": torch.complex64,
}
if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
_TYPES.update(
{
"U64": torch.uint64,
"U32": torch.uint32,
"U16": torch.uint16,
}
)
def load_safetensors(ckpt):
f = open(ckpt, "rb")
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
sd = {}
for name, info in header.items():
if name == "__metadata__": continue
start, end = info["data_offsets"]
sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
return sd, header.get("__metadata__", {}),
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
else:
assert False
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
if return_metadata:
metadata = f.metadata()
if enables_dynamic_vram():
sd, metadata = load_safetensors(ckpt)
if not return_metadata:
metadata = None
else:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
if return_metadata:
metadata = f.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]