implement lightweight safetensors with READ mmap

The mmap used by safetensors is hardcoded to copy-on-write (CoW), which
forcibly consumes Windows commit charge even though the load itself is
zero-copy. RIP. Implement safetensors parsing in PyTorch itself on top of
a READ mmap so we don't get commit-charged for every open model.
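
For reference, a minimal standalone sketch (not part of this commit) of the two
mmap modes in Python. Both give zero-copy reads; the difference is that on
Windows a CoW mapping reserves commit charge for every mapped page up front,
since any page could later be written privately, while a READ mapping is backed
by the file itself:

import mmap

def map_readonly(path):
    # Shared read-only mapping: pages are backed by the file, so Windows
    # does not count them against the system commit limit.
    with open(path, "rb") as f:
        return mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

def map_cow(path):
    # Copy-on-write mapping: writable in this process only. Windows must
    # guarantee backing store for a full private copy, so the entire file
    # size is added to the commit charge even if nothing is ever written.
    with open(path, "rb") as f:
        return mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_COPY)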
Rattus 2026-01-18 22:00:50 +10:00
parent c1a9b4d565
commit 876b886f2a


@@ -28,8 +28,12 @@ import logging
 import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
-from comfy.cli_args import args
+from comfy.cli_args import args, enables_dynamic_vram
 import json
+import mmap
+import ctypes
+import packaging.version
 
 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -55,21 +59,72 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
 else:
     logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
 
+# Current as of safetensors 0.7.0
+_TYPES = {
+    "F64": torch.float64,
+    "F32": torch.float32,
+    "F16": torch.float16,
+    "BF16": torch.bfloat16,
+    "I64": torch.int64,
+    "I32": torch.int32,
+    "I16": torch.int16,
+    "I8": torch.int8,
+    "U8": torch.uint8,
+    "BOOL": torch.bool,
+    "F8_E4M3": torch.float8_e4m3fn,
+    "F8_E5M2": torch.float8_e5m2,
+    "C64": torch.complex64,
+}
+
+if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"):
+    _TYPES.update(
+        {
+            "U64": torch.uint64,
+            "U32": torch.uint32,
+            "U16": torch.uint16,
+        }
+    )
+def load_safetensors(ckpt):
+    # READ (shared, read-only) mapping: zero-copy and, unlike CoW, exempt
+    # from Windows commit charge.
+    f = open(ckpt, "rb")
+    mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
+    # safetensors layout: 8-byte little-endian header length, JSON header,
+    # then the raw tensor data area.
+    header_size = struct.unpack("<Q", mapping[:8])[0]
+    header = json.loads(mapping[8:8 + header_size].decode("utf-8"))
+    data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
+    sd = {}
+    for name, info in header.items():
+        if name == "__metadata__":
+            continue
+        start, end = info["data_offsets"]
+        # Each tensor is a typed, reshaped zero-copy view into the mapping.
+        sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
+    return sd, header.get("__metadata__", {})
+
 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     if device is None:
         device = torch.device("cpu")
+    else:
+        assert False
     metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
-            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
-                sd = {}
-                for k in f.keys():
-                    tensor = f.get_tensor(k)
-                    if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
-                        tensor = tensor.to(device=device, copy=True)
-                    sd[k] = tensor
-                if return_metadata:
-                    metadata = f.metadata()
+            if enables_dynamic_vram():
+                sd, metadata = load_safetensors(ckpt)
+                if not return_metadata:
+                    metadata = None
+            else:
+                with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+                    sd = {}
+                    for k in f.keys():
+                        tensor = f.get_tensor(k)
+                        if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
+                            tensor = tensor.to(device=device, copy=True)
+                        sd[k] = tensor
+                    if return_metadata:
+                        metadata = f.metadata()
         except Exception as e:
             if len(e.args) > 0:
                 message = e.args[0]
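
Hypothetical usage of the new loader (the file name and tensor key below are
made up for illustration): tensors come back as zero-copy views over the READ
mapping, so no data is read from disk until a tensor's pages are actually
touched.

import torch

sd, metadata = load_safetensors("model.safetensors")  # hypothetical path
for name, t in sd.items():
    print(name, t.dtype, tuple(t.shape))

# Pages fault in from the file only when used, e.g. on a materializing copy:
w = sd["diffusion_model.out.weight"].to(device="cpu", copy=True)  # hypothetical key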