diff --git a/comfy/memory_management.py b/comfy/memory_management.py index c43f0c4a2..962addb27 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,5 @@ import math import ctypes -import threading import dataclasses import torch from typing import NamedTuple @@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor class TensorFileSlice(NamedTuple): file_ref: object - thread_id: int + lock: object offset: int size: int @@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N file_obj = info.file_ref if (destination.device.type != "cpu" or file_obj is None - or threading.get_ident() != info.thread_id or destination.numel() * destination.element_size() < info.size or tensor.numel() * tensor.element_size() != info.size or tensor.storage_offset() != 0 @@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if hostbuf is not None: stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 device_ptr = destination2.data_ptr() if destination2 is not None else 0 - hostbuf.read_file_slice(file_obj, info.offset, info.size, - offset=destination.data_ptr() - hostbuf.get_raw_address(), - stream=stream_ptr, - device_ptr=device_ptr, - device=None if destination2 is None else destination2.device.index) + with info.lock: + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) return True buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) try: - file_obj.seek(info.offset) - done = 0 - while done < info.size: - try: - n = file_obj.readinto(view[done:]) - except OSError: - return False - if n <= 0: - return False - done += n + with info.lock: + file_obj.seek(info.offset) + done = 0 + while done < info.size: + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False + if n <= 0: + return False + done += n return True finally: view.release() diff --git a/comfy/utils.py b/comfy/utils.py index 6b12676d2..abfd4079d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,6 +86,7 @@ def load_safetensors(ckpt): import comfy_aimdo.model_mmap f = open(ckpt, "rb", buffering=0) + file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) @@ -111,7 +112,7 @@ def load_safetensors(ckpt): storage = tensor.untyped_storage() setattr(storage, "_comfy_tensor_file_slice", - comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) + comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) sd[name] = tensor diff --git a/requirements.txt b/requirements.txt index e20b6e044..381e7d05f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.4.3 +comfy-aimdo==0.4.4 requests simpleeval>=1.0.0 blake3