Skip fp4 matrix mult on devices that don't support it. (#11677)

This commit is contained in:
comfyanonymous 2026-01-06 15:07:26 -08:00 committed by GitHub
parent 6e9ee55cdd
commit 2c03884f5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 4 deletions

View File

@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None):
return True
def supports_nvfp4_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):

View File

@ -493,11 +493,12 @@ from .quant_ops import (
)
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
def reset_parameters(self):
return None
@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
if (
fp8_compute and