From 2c03884f5fb7fa213161dfe1e9a09a8e8c4b6062 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 6 Jan 2026 15:07:26 -0800
Subject: [PATCH] Skip fp4 matrix mult on devices that don't support it.
 (#11677)

---
 comfy/model_management.py | 10 ++++++++++
 comfy/ops.py              | 21 +++++++++++++++++----
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 22f4de044..928282092 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None):
 
     return True
 
+def supports_nvfp4_compute(device=None):
+    if not is_nvidia():
+        return False
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major < 10:
+        return False
+
+    return True
+
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
diff --git a/comfy/ops.py b/comfy/ops.py
index f5e1e9230..8f9fdce36 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -493,11 +493,12 @@ from .quant_ops import (
 )
 
 
-def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
     class MixedPrecisionOps(manual_cast):
         _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
+        _disabled = disabled
 
         class Linear(torch.nn.Module, CastWeightBiasOp):
             def __init__(
@@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 self.tensor_class = None
 
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+                self._full_precision_mm_config = False
 
             def reset_parameters(self):
                 return None
@@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
+                    self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
                     if not self._full_precision_mm:
-                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+                        self._full_precision_mm = self._full_precision_mm_config
+
+                    if self.quant_format in MixedPrecisionOps._disabled:
+                        self._full_precision_mm = True
 
                     if self.quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")
@@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
 
                 quant_conf = {"format": self.quant_format}
-                if self._full_precision_mm:
+                if self._full_precision_mm_config:
                     quant_conf["full_precision_matrix_mult"] = True
                 sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
                 return sd
@@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     # TODO: if we support more ops this needs to be more granular
+    nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
         logging.info("Using mixed precision operations")
-        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+        disabled = set()
+        if not nvfp4_compute:
+            disabled.add("nvfp4")
+        if not fp8_compute:
+            disabled.add("float8_e4m3fn")
+            disabled.add("float8_e5m2")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
 
     if (
         fp8_compute and
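
Note (not part of the patch): the following is a minimal, standalone sketch of the gating logic this diff introduces, runnable without a GPU. It mirrors supports_nvfp4_compute() and the `disabled` set built in pick_operations(): an NVIDIA device with CUDA compute capability major version below 10 cannot do nvfp4 matrix multiplication, so the "nvfp4" format lands in `disabled` and the affected layers fall back to full-precision matmul. The names FakeDeviceProps and build_disabled_formats are illustrative stand-ins, not ComfyUI APIs.

# Sketch only: fakes torch.cuda.get_device_properties() so the check runs anywhere.
from dataclasses import dataclass

@dataclass
class FakeDeviceProps:
    major: int  # CUDA compute capability major version

def supports_nvfp4_compute_sketch(props: FakeDeviceProps, is_nvidia: bool = True) -> bool:
    # Same test as the patch: NVIDIA device with compute capability major >= 10.
    return is_nvidia and props.major >= 10

def build_disabled_formats(props: FakeDeviceProps, fp8_compute: bool) -> set:
    # Same shape as the new code in pick_operations(): collect quant formats the
    # device cannot run natively; mixed_precision_ops() then forces those layers
    # to full-precision matrix multiplication.
    disabled = set()
    if not supports_nvfp4_compute_sketch(props):
        disabled.add("nvfp4")
    if not fp8_compute:
        disabled.add("float8_e4m3fn")
        disabled.add("float8_e5m2")
    return disabled

# Compute capability 8.x card: nvfp4 is skipped, fp8 formats stay enabled.
print(build_disabled_formats(FakeDeviceProps(major=8), fp8_compute=True))   # {'nvfp4'}
# Compute capability 10.x card: nothing is disabled.
print(build_disabled_formats(FakeDeviceProps(major=10), fp8_compute=True))  # set()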