feat: Support mxfp8 (#12907)

This commit is contained in:
Jukka Seppänen 2026-03-15 00:36:29 +02:00 committed by GitHub
parent e0982a7174
commit 1c5db7397d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 115 additions and 0 deletions

View File

@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@ -1712,6 +1712,19 @@ def supports_nvfp4_compute(device=None):
return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):

View File

@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")

View File

@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name):
return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float
# ==============================================================================
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@ -157,6 +196,14 @@ QUANT_ALGOS = {
},
}
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ==============================================================================
# Re-exports for backward compatibility