mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge 1ff364873a into 65045730a6
This commit is contained in:
commit
c57ee6a4dd
@ -119,6 +119,76 @@ def load_safetensors(ckpt):
|
|||||||
return sd, header.get("__metadata__", {}),
|
return sd, header.get("__metadata__", {}),
|
||||||
|
|
||||||
|
|
||||||
|
def load_safetensors_no_mmap(ckpt, device=None, return_metadata=False):
    """Load a .safetensors / .sft file without ever mmap'ing it.

    safetensors.safe_open() (and therefore safetensors.torch.load_file) always
    mmaps the underlying file in Rust.  On systems with unified CPU/GPU memory
    like NVIDIA Grace Blackwell / DGX Spark, Apple Silicon, AMD APUs, etc.,
    this is fatal for large models: the OS page-cache pages backing the mmap
    and any subsequent device copy both reside in the same physical memory
    pool, doubling peak memory and causing OOM well before the hardware limit.

    See: https://github.com/Comfy-Org/ComfyUI/issues/10896
         https://github.com/safetensors/safetensors/issues/758
         https://github.com/safetensors/safetensors/pull/759

    Temporary workaround until upstream safetensors exposes a public
    ``mmap=False`` option: parse the safetensors header ourselves and read
    each tensor straight from disk into a per-tensor ``bytearray`` via
    ``readinto``, then zero-copy-wrap it as a torch tensor with
    ``torch.frombuffer``.  Peak memory is one model copy (plus, for a non-CPU
    target, the bytes of the single tensor currently in flight).

    Returns ``(state_dict, metadata)``; ``metadata`` is ``None`` unless
    ``return_metadata`` is true.
    """
    target = torch.device("cpu") if device is None else device

    tensors = {}
    metadata = None
    with open(ckpt, "rb") as fh:
        # Layout: 8-byte little-endian u64 header length, then the JSON
        # header, then the raw tensor payload.
        size_prefix = fh.read(8)
        if len(size_prefix) != 8:
            raise ValueError("HeaderTooLarge: file is too small to be a valid safetensors file: {}".format(ckpt))
        (header_len,) = struct.unpack("<Q", size_prefix)
        raw_header = fh.read(header_len)
        if len(raw_header) != header_len:
            raise ValueError("MetadataIncompleteBuffer: truncated header in {}".format(ckpt))
        index = json.loads(raw_header.decode("utf-8"))
        # All data_offsets in the header are relative to this file position.
        payload_base = 8 + header_len

        if return_metadata:
            metadata = index.get("__metadata__", {})

        for name, entry in index.items():
            if name == "__metadata__":
                continue

            torch_dtype = _TYPES[entry["dtype"]]
            dims = entry["shape"]
            begin, finish = entry["data_offsets"]
            nbytes = finish - begin

            if nbytes == 0:
                # Zero-length payload (e.g. an empty tensor): nothing to read.
                t = torch.empty(dims, dtype=torch_dtype)
            else:
                scratch = bytearray(nbytes)
                fh.seek(payload_base + begin)
                # readinto may return short counts; loop until the buffer is
                # full, failing loudly on EOF.
                mv = memoryview(scratch)
                filled = 0
                while filled < nbytes:
                    got = fh.readinto(mv[filled:])
                    if not got:
                        raise ValueError("MetadataIncompleteBuffer: unexpected EOF reading tensor {!r} from {}".format(name, ckpt))
                    filled += got
                # Zero-copy wrap of the bytearray; no second CPU copy is made.
                t = torch.frombuffer(scratch, dtype=torch_dtype).reshape(dims)

            if target.type != "cpu":
                t = t.to(device=target)
            tensors[name] = t

    return tensors, metadata

||||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -129,14 +199,15 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
sd, metadata = load_safetensors(ckpt)
|
sd, metadata = load_safetensors(ckpt)
|
||||||
if not return_metadata:
|
if not return_metadata:
|
||||||
metadata = None
|
metadata = None
|
||||||
|
elif DISABLE_MMAP:
|
||||||
|
sd, metadata = load_safetensors_no_mmap(ckpt, device=device, return_metadata=return_metadata)
|
||||||
|
if not return_metadata:
|
||||||
|
metadata = None
|
||||||
else:
|
else:
|
||||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||||
sd = {}
|
sd = {}
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
tensor = f.get_tensor(k)
|
sd[k] = f.get_tensor(k)
|
||||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
|
||||||
tensor = tensor.to(device=device, copy=True)
|
|
||||||
sd[k] = tensor
|
|
||||||
if return_metadata:
|
if return_metadata:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user