Allow pinning quantized tensors. (#10873)

This commit is contained in:
comfyanonymous 2025-11-24 23:48:20 -08:00 committed by GitHub
parent 25022e0b09
commit b6805429b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 1 deletions

View File

@ -1098,13 +1098,14 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor) is not torch.nn.parameter.Parameter:
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False
if not is_device_cpu(tensor.device):
@ -1124,6 +1125,9 @@ def pin_memory(tensor):
return False
ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size

View File

@ -228,6 +228,14 @@ class QuantizedTensor(torch.Tensor):
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self):
return self._qdata.is_contiguous()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)