mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 22:39:24 +08:00
Implement seek and read for pins
Sourcing pins from an mmap is bad because it is a CPU->CPU copy that attempts to fully buffer the same data twice. Instead, use seek and read, which avoids the mmap buffering while usually being a faster read in the first place (avoiding mmap page faults, etc.).
This commit is contained in:
parent
4a8cf359fe
commit
6fbbcc4cb7
@ -1,9 +1,51 @@
|
|||||||
import math
|
import math
|
||||||
|
import ctypes
|
||||||
|
import threading
|
||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
import logging
|
||||||
|
|
||||||
from comfy.quant_ops import QuantizedTensor
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class TensorFileSlice(NamedTuple):
    """Describes where a tensor's raw bytes live inside the file it came from.

    An instance is attached to a tensor's untyped storage so the bytes can
    later be re-read with seek + readinto.
    """

    # Open file object to read from; readers treat None as "no file available".
    file_ref: object
    # threading.get_ident() of the thread that recorded this slice.
    thread_id: int
    # Absolute byte offset of the tensor data within the file.
    offset: int
    # Length of the tensor data in bytes.
    size: int
|
||||||
|
|
||||||
|
|
||||||
|
def read_tensor_file_slice_into(tensor, destination):
    """Fill ``destination`` by reading ``tensor``'s bytes directly from its file.

    The slice description is looked up on ``tensor.untyped_storage()`` under
    the ``_comfy_tensor_file_slice`` attribute. When one is present and usable,
    ``info.size`` bytes starting at ``info.offset`` are read with
    seek + readinto straight into the memory at ``destination.data_ptr()``.

    Args:
        tensor: source tensor whose storage may carry a file slice description.
        destination: tensor whose memory receives the raw bytes.

    Returns:
        True if the bytes were read into ``destination`` (or the slice is
        empty). False if the fast path does not apply or the read came up
        short; the caller is then expected to copy the data another way.
        ``destination`` may have been partially written when False is returned
        after a short read.
    """
    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    if (destination.device.type != "cpu"
            or file_obj is None
            # seek + readinto share the file position, so only the recording thread may use it
            or threading.get_ident() != info.thread_id
            or destination.numel() * destination.element_size() < info.size):
        return False

    # Bytes are written linearly from data_ptr(); a strided destination would
    # receive bytes outside its own elements, so leave it to the caller.
    if not destination.is_contiguous():
        return False

    if info.size == 0:
        return True

    # Expose the destination memory as a writable byte buffer for readinto().
    buf_type = ctypes.c_ubyte * info.size
    view = memoryview(buf_type.from_address(destination.data_ptr()))

    try:
        file_obj.seek(info.offset)
        done = 0
        while done < info.size:
            n = file_obj.readinto(view[done:])
            # 0 means EOF; a raw file may also return None when no data is
            # available. Either way the slice cannot be completed.
            if not n:
                return False
            done += n
        return True
    finally:
        view.release()
|
||||||
|
|
||||||
class TensorGeometry(NamedTuple):
|
class TensorGeometry(NamedTuple):
|
||||||
shape: any
|
shape: any
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
|||||||
@ -1225,6 +1225,8 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
dest_view = dest_views.pop(0)
|
dest_view = dest_views.pop(0)
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
continue
|
continue
|
||||||
|
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||||
|
continue
|
||||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -33,6 +33,7 @@ from comfy.cli_args import args
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import mmap
|
import mmap
|
||||||
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
@ -81,14 +82,14 @@ _TYPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def load_safetensors(ckpt):
|
def load_safetensors(ckpt):
|
||||||
f = open(ckpt, "rb")
|
f = open(ckpt, "rb", buffering=0)
|
||||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||||
mv = memoryview(mapping)
|
mv = memoryview(mapping)
|
||||||
|
|
||||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||||
|
|
||||||
mv = mv[8 + header_size:]
|
mv = mv[(data_base_offset := 8 + header_size):]
|
||||||
|
|
||||||
sd = {}
|
sd = {}
|
||||||
for name, info in header.items():
|
for name, info in header.items():
|
||||||
@ -102,7 +103,11 @@ def load_safetensors(ckpt):
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
#We are working with read-only RAM by design
|
#We are working with read-only RAM by design
|
||||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||||
|
setattr(tensor.untyped_storage(),
|
||||||
|
"_comfy_tensor_file_slice",
|
||||||
|
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||||
|
sd[name] = tensor
|
||||||
|
|
||||||
return sd, header.get("__metadata__", {}),
|
return sd, header.get("__metadata__", {}),
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user