utils: use the aimdo mmap to load sft files

This commit is contained in:
Rattus 2026-03-11 19:48:40 +10:00
parent b487aacc84
commit d8b4016d67

View File

@ -20,6 +20,8 @@
import torch import torch
import math import math
import struct import struct
import ctypes
import os
import comfy.memory_management import comfy.memory_management
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
@ -32,7 +34,6 @@ from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
import json import json
import time import time
import mmap
import threading import threading
import warnings import warnings
@ -82,12 +83,15 @@ _TYPES = {
} }
def load_safetensors(ckpt): def load_safetensors(ckpt):
f = open(ckpt, "rb", buffering=0) import comfy_aimdo.model_mmap
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0] f = open(ckpt, "rb", buffering=0)
header = json.loads(mapping[8:8+header_size].decode("utf-8")) model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):] mv = mv[(data_base_offset := 8 + header_size):]
@ -108,6 +112,7 @@ def load_safetensors(ckpt):
setattr(storage, setattr(storage,
"_comfy_tensor_file_slice", "_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False) setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor sd[name] = tensor