mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 16:07:30 +08:00
Merge pull request #14052 from rattus128/prs/worksplit-t-load-fix
fixup threaded loader with worksplit multi-gpu
This commit is contained in:
commit
cb83c41db7
@ -1,6 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
import ctypes
|
import ctypes
|
||||||
import threading
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
|
|||||||
|
|
||||||
class TensorFileSlice(NamedTuple):
|
class TensorFileSlice(NamedTuple):
|
||||||
file_ref: object
|
file_ref: object
|
||||||
thread_id: int
|
lock: object
|
||||||
offset: int
|
offset: int
|
||||||
size: int
|
size: int
|
||||||
|
|
||||||
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
file_obj = info.file_ref
|
file_obj = info.file_ref
|
||||||
if (destination.device.type != "cpu"
|
if (destination.device.type != "cpu"
|
||||||
or file_obj is None
|
or file_obj is None
|
||||||
or threading.get_ident() != info.thread_id
|
|
||||||
or destination.numel() * destination.element_size() < info.size
|
or destination.numel() * destination.element_size() < info.size
|
||||||
or tensor.numel() * tensor.element_size() != info.size
|
or tensor.numel() * tensor.element_size() != info.size
|
||||||
or tensor.storage_offset() != 0
|
or tensor.storage_offset() != 0
|
||||||
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
if hostbuf is not None:
|
if hostbuf is not None:
|
||||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
with info.lock:
|
||||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||||
stream=stream_ptr,
|
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||||
device_ptr=device_ptr,
|
stream=stream_ptr,
|
||||||
device=None if destination2 is None else destination2.device.index)
|
device_ptr=device_ptr,
|
||||||
|
device=None if destination2 is None else destination2.device.index)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
buf_type = ctypes.c_ubyte * info.size
|
buf_type = ctypes.c_ubyte * info.size
|
||||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
file_obj.seek(info.offset)
|
with info.lock:
|
||||||
done = 0
|
file_obj.seek(info.offset)
|
||||||
while done < info.size:
|
done = 0
|
||||||
try:
|
while done < info.size:
|
||||||
n = file_obj.readinto(view[done:])
|
try:
|
||||||
except OSError:
|
n = file_obj.readinto(view[done:])
|
||||||
return False
|
except OSError:
|
||||||
if n <= 0:
|
return False
|
||||||
return False
|
if n <= 0:
|
||||||
done += n
|
return False
|
||||||
|
done += n
|
||||||
return True
|
return True
|
||||||
finally:
|
finally:
|
||||||
view.release()
|
view.release()
|
||||||
|
|||||||
@ -86,6 +86,7 @@ def load_safetensors(ckpt):
|
|||||||
import comfy_aimdo.model_mmap
|
import comfy_aimdo.model_mmap
|
||||||
|
|
||||||
f = open(ckpt, "rb", buffering=0)
|
f = open(ckpt, "rb", buffering=0)
|
||||||
|
file_lock = threading.Lock()
|
||||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||||
file_size = os.path.getsize(ckpt)
|
file_size = os.path.getsize(ckpt)
|
||||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||||
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
|
|||||||
storage = tensor.untyped_storage()
|
storage = tensor.untyped_storage()
|
||||||
setattr(storage,
|
setattr(storage,
|
||||||
"_comfy_tensor_file_slice",
|
"_comfy_tensor_file_slice",
|
||||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||||
sd[name] = tensor
|
sd[name] = tensor
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
|||||||
filelock
|
filelock
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.8
|
comfy-kitchen>=0.2.8
|
||||||
comfy-aimdo==0.4.3
|
comfy-aimdo==0.4.4
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user