diff --git a/comfy/model_management.py b/comfy/model_management.py index 81c89b180..e7d327c74 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1662,6 +1662,16 @@ def supports_nvfp4_compute(device=None): return True +def supports_mxfp8_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index 8c0aaec55..6306fd1be 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -966,16 +966,18 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) + mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") disabled = set() if not nvfp4_compute: disabled.add("nvfp4") + if not mxfp8_compute: + disabled.add("mxfp8") if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") - disabled.add("mxfp8") return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if (