diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 66df35244..df23e4497 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -38,7 +38,10 @@ def read_tensor_file_slice_into(tensor, destination): file_obj.seek(info.offset) done = 0 while done < info.size: - n = file_obj.readinto(view[done:]) + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False if n <= 0: return False done += n diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 8acc327a7..f6fb806c4 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -1,6 +1,7 @@ -import torch import comfy.model_management import comfy.memory_management +import comfy_aimdo.host_buffer +import comfy_aimdo.torch from comfy.cli_args import args @@ -12,18 +13,31 @@ def pin_memory(module): return #FIXME: This is a RAM cache trigger event size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) - pin = torch.empty((size,), dtype=torch.uint8) - if comfy.model_management.pin_memory(pin): - module._pin = pin - else: + + if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: module.pin_failed = True return False + + try: + hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + except RuntimeError: + module.pin_failed = True + return False + + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) + module._pin_hostbuf = hostbuf + comfy.model_management.TOTAL_PINNED_MEMORY += size return True def unpin_memory(module): if get_pin(module) is None: return 0 size = module._pin.numel() * module._pin.element_size() - comfy.model_management.unpin_memory(module._pin) + + comfy.model_management.TOTAL_PINNED_MEMORY -= size + if comfy.model_management.TOTAL_PINNED_MEMORY < 0: + comfy.model_management.TOTAL_PINNED_MEMORY = 0 + del module._pin + del module._pin_hostbuf return size