Limit amount of pinned memory on windows to prevent issues. (#10638)

This commit is contained in:
comfyanonymous 2025-11-04 14:37:50 -08:00 committed by GitHub
parent a389ee01bb
commit 7f3e4d486c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1082,8 +1082,20 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device) non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
if PerformanceFeature.PinnedMem in args.fast:
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
else:
MAX_PINNED_MEMORY = -1
def pin_memory(tensor): def pin_memory(tensor):
if PerformanceFeature.PinnedMem not in args.fast: global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False return False
if not is_nvidia(): if not is_nvidia():
@ -1092,13 +1104,21 @@ def pin_memory(tensor):
if not is_device_cpu(tensor.device): if not is_device_cpu(tensor.device):
return False return False
if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0: size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True return True
return False return False
def unpin_memory(tensor): def unpin_memory(tensor):
if PerformanceFeature.PinnedMem not in args.fast: global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False return False
if not is_nvidia(): if not is_nvidia():
@ -1107,7 +1127,11 @@ def unpin_memory(tensor):
if not is_device_cpu(tensor.device): if not is_device_cpu(tensor.device):
return False return False
if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0: ptr = tensor.data_ptr()
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True return True
return False return False