mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
Allow pinning quantized tensors. (#10873)
This commit is contained in:
parent
25022e0b09
commit
b6805429b9
@ -1098,13 +1098,14 @@ if not args.disable_pinned_memory:
|
|||||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||||
|
|
||||||
|
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||||
|
|
||||||
def pin_memory(tensor):
|
def pin_memory(tensor):
|
||||||
global TOTAL_PINNED_MEMORY
|
global TOTAL_PINNED_MEMORY
|
||||||
if MAX_PINNED_MEMORY <= 0:
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if type(tensor) is not torch.nn.parameter.Parameter:
|
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not is_device_cpu(tensor.device):
|
if not is_device_cpu(tensor.device):
|
||||||
@ -1124,6 +1125,9 @@ def pin_memory(tensor):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
ptr = tensor.data_ptr()
|
ptr = tensor.data_ptr()
|
||||||
|
if ptr == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||||
PINNED_MEMORY[ptr] = size
|
PINNED_MEMORY[ptr] = size
|
||||||
TOTAL_PINNED_MEMORY += size
|
TOTAL_PINNED_MEMORY += size
|
||||||
|
|||||||
@ -228,6 +228,14 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
new_kwargs = dequant_arg(kwargs)
|
new_kwargs = dequant_arg(kwargs)
|
||||||
return func(*new_args, **new_kwargs)
|
return func(*new_args, **new_kwargs)
|
||||||
|
|
||||||
|
def data_ptr(self):
|
||||||
|
return self._qdata.data_ptr()
|
||||||
|
|
||||||
|
def is_pinned(self):
|
||||||
|
return self._qdata.is_pinned()
|
||||||
|
|
||||||
|
def is_contiguous(self):
|
||||||
|
return self._qdata.is_contiguous()
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Generic Utilities (Layout-Agnostic Operations)
|
# Generic Utilities (Layout-Agnostic Operations)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user