mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 16:07:30 +08:00
Merge pull request #14052 from rattus128/prs/worksplit-t-load-fix
fixup threaded loader with worksplit multi-gpu
This commit is contained in:
commit
cb83c41db7
@ -1,6 +1,5 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
thread_id: int
|
||||
lock: object
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
with info.lock:
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
with info.lock:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
@ -86,6 +86,7 @@ def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
file_lock = threading.Lock()
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
sd[name] = tensor
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo==0.4.3
|
||||
comfy-aimdo==0.4.4
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
Loading…
Reference in New Issue
Block a user