Compare commits

..

No commits in common. "015a0599d08f1072155b9213d488b73e502fea3c" and "25022e0b0965975b35bcaf28b153184d60a4f9de" have entirely different histories.

3 changed files with 5 additions and 18 deletions

View File

@ -1098,14 +1098,13 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
if type(tensor) is not torch.nn.parameter.Parameter:
return False
if not is_device_cpu(tensor.device):
@ -1125,9 +1124,6 @@ def pin_memory(tensor):
return False
ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size

View File

@ -646,12 +646,11 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype)
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@ -228,14 +228,6 @@ class QuantizedTensor(torch.Tensor):
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self):
return self._qdata.is_contiguous()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
@ -405,8 +397,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = {