import ctypes
import logging
import threading
from typing import NamedTuple


class TensorFileSlice(NamedTuple):
    """Location of a tensor's raw bytes inside an open checkpoint file.

    An instance is looked up on a tensor's untyped storage under the
    attribute name ``_comfy_tensor_file_slice`` by
    :func:`read_tensor_file_slice_into`.
    """

    # Open binary file object; only ``seek`` and ``readinto`` are used.
    # ``None`` disables the direct-read path.
    file_ref: object
    # ``threading.get_ident()`` of the only thread allowed to use ``file_ref``.
    thread_id: int
    # Absolute byte offset of the tensor data, passed to ``file_ref.seek``.
    offset: int
    # Number of bytes that make up the tensor data.
    size: int


def _is_raw_byte_copy(tensor, destination, nbytes):
    """Return True when ``destination.copy_(tensor)`` equals a flat byte copy.

    The file-slice metadata is attached to the *storage*, which is shared by
    every view of it. A transposed, offset or dtype-reinterpreted view has the
    same storage but a different logical layout, so streaming the file bytes
    straight into ``destination`` is only valid when both tensors are plain,
    contiguous, same-dtype buffers of exactly ``nbytes`` bytes.
    """
    return (
        tensor.dtype == destination.dtype
        and tensor.is_contiguous()
        and destination.is_contiguous()
        and tensor.storage_offset() == 0
        and tensor.numel() * tensor.element_size() == nbytes
        and destination.numel() * destination.element_size() == nbytes
    )


def read_tensor_file_slice_into(tensor, destination):
    """Fill ``destination`` by reading ``tensor``'s bytes directly from disk.

    This is a best-effort fast path: it never raises for I/O problems and
    callers are expected to fall back to a regular ``destination.copy_(tensor)``
    whenever it returns ``False`` (``destination`` may then hold partially
    written data, which the fallback copy overwrites).

    Args:
        tensor: Source tensor whose untyped storage may carry a
            ``_comfy_tensor_file_slice`` attribute (a :class:`TensorFileSlice`).
        destination: CPU tensor that receives the bytes in place.

    Returns:
        ``True`` if ``destination`` now holds the complete tensor data,
        ``False`` if the direct read was not applicable or did not complete.
    """
    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    if (file_obj is None
            or destination.device.type != "cpu"
            # The handle's seek position is shared state, so only the thread
            # recorded in the metadata may use it.
            or threading.get_ident() != info.thread_id):
        return False

    # The bytes are written through a raw pointer, so the layouts must match
    # exactly; anything else is left to ``copy_``.
    if not _is_raw_byte_copy(tensor, destination, info.size):
        return False

    if info.size == 0:
        return True

    buf_type = ctypes.c_ubyte * info.size
    try:
        # Expose the destination memory as a writable buffer for ``readinto``.
        with memoryview(buf_type.from_address(destination.data_ptr())) as view:
            file_obj.seek(info.offset)
            done = 0
            while done < info.size:
                n = file_obj.readinto(view[done:])
                # ``readinto`` yields 0 at EOF and may yield None on a raw
                # non-blocking stream; both mean the slice cannot be completed.
                if n is None or n <= 0:
                    return False
                done += n
    except (OSError, ValueError) as exc:
        # ValueError covers a handle that has been closed in the meantime.
        logging.debug("Direct tensor file read failed (%s); falling back to copy", exc)
        return False
    return True
load_safetensors(ckpt): - f = open(ckpt, "rb") + f = open(ckpt, "rb", buffering=0) mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) mv = memoryview(mapping) header_size = struct.unpack("