Merge pull request #14052 from rattus128/prs/worksplit-t-load-fix

fixup threaded loader with worksplit multi-gpu
This commit is contained in:
Jedrzej Kosinski 2026-05-22 16:36:33 -07:00 committed by GitHub
commit cb83c41db7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 20 deletions

View File

@ -1,6 +1,5 @@
import math import math
import ctypes import ctypes
import threading
import dataclasses import dataclasses
import torch import torch
from typing import NamedTuple from typing import NamedTuple
@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple): class TensorFileSlice(NamedTuple):
file_ref: object file_ref: object
thread_id: int lock: object
offset: int offset: int
size: int size: int
@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
file_obj = info.file_ref file_obj = info.file_ref
if (destination.device.type != "cpu" if (destination.device.type != "cpu"
or file_obj is None or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0 or tensor.storage_offset() != 0
@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if hostbuf is not None: if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0 device_ptr = destination2.data_ptr() if destination2 is not None else 0
hostbuf.read_file_slice(file_obj, info.offset, info.size, with info.lock:
offset=destination.data_ptr() - hostbuf.get_raw_address(), hostbuf.read_file_slice(file_obj, info.offset, info.size,
stream=stream_ptr, offset=destination.data_ptr() - hostbuf.get_raw_address(),
device_ptr=device_ptr, stream=stream_ptr,
device=None if destination2 is None else destination2.device.index) device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
return True return True
buf_type = ctypes.c_ubyte * info.size buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr())) view = memoryview(buf_type.from_address(destination.data_ptr()))
try: try:
file_obj.seek(info.offset) with info.lock:
done = 0 file_obj.seek(info.offset)
while done < info.size: done = 0
try: while done < info.size:
n = file_obj.readinto(view[done:]) try:
except OSError: n = file_obj.readinto(view[done:])
return False except OSError:
if n <= 0: return False
return False if n <= 0:
done += n return False
done += n
return True return True
finally: finally:
view.release() view.release()

View File

@ -86,6 +86,7 @@ def load_safetensors(ckpt):
import comfy_aimdo.model_mmap import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0) f = open(ckpt, "rb", buffering=0)
file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt) file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -111,7 +112,7 @@ def load_safetensors(ckpt):
storage = tensor.untyped_storage() storage = tensor.untyped_storage()
setattr(storage, setattr(storage,
"_comfy_tensor_file_slice", "_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
sd[name] = tensor sd[name] = tensor

View File

@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.8 comfy-kitchen>=0.2.8
comfy-aimdo==0.4.3 comfy-aimdo==0.4.4
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3