diff --git a/comfy/utils.py b/comfy/utils.py index 635dc4146..8bc7d7a96 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,8 +28,12 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram import json +import mmap +import ctypes + +import packaging MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -55,21 +59,72 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, +} +if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"): + _TYPES.update( + { + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, + } + ) + +def load_safetensors(ckpt): + f = open(ckpt, "rb") + mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + header_size = struct.unpack(" 0: message = e.args[0]