From d59731cf264d2e6316e04586a50c28bedaf5ac0d Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Thu, 12 Mar 2026 18:25:42 +0200
Subject: [PATCH] Support mxfp8

---
 comfy/ops.py       | 16 ++++++++++++++++
 comfy/quant_ops.py |  9 +++++++++
 2 files changed, 25 insertions(+)

diff --git a/comfy/ops.py b/comfy/ops.py
index 87b36b5c5..81a97a30b 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -801,6 +801,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     orig_shape=(self.out_features, self.in_features),
                 )
 
+            elif self.quant_format == "mxfp8":
+                # MXFP8: E8M0 block scales stored as uint8 in safetensors
+                block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
+                                                     dtype=torch.uint8)
+
+                if block_scale is None:
+                    raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
+
+                block_scale = block_scale.view(torch.float8_e8m0fnu)
+
+                params = layout_cls.Params(
+                    scale=block_scale,
+                    orig_dtype=MixedPrecisionOps._compute_dtype,
+                    orig_shape=(self.out_features, self.in_features),
+                )
+
             elif self.quant_format == "nvfp4":
                 # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
                 tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 15a4f457b..43c6fd7ce 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -8,6 +8,7 @@ try:
         QuantizedLayout,
         TensorCoreFP8Layout as _CKFp8Layout,
         TensorCoreNVFP4Layout as _CKNvfp4Layout,
+        TensorCoreMXFP8Layout,
         register_layout_op,
         register_layout_class,
         get_layout_class,
@@ -137,6 +138,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
 register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
 register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
 register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
+register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
 
 QUANT_ALGOS = {
     "float8_e4m3fn": {
@@ -155,6 +157,12 @@ QUANT_ALGOS = {
         "comfy_tensor_layout": "TensorCoreNVFP4Layout",
         "group_size": 16,
     },
+    "mxfp8": {
+        "storage_t": torch.float8_e4m3fn,
+        "parameters": {"weight_scale", "input_scale"},
+        "comfy_tensor_layout": "TensorCoreMXFP8Layout",
+        "group_size": 32,
+    },
 }
 
 
@@ -169,6 +177,7 @@ __all__ = [
     "TensorCoreFP8E4M3Layout",
     "TensorCoreFP8E5M2Layout",
     "TensorCoreNVFP4Layout",
+    "TensorCoreMXFP8Layout",
     "QUANT_ALGOS",
     "register_layout_op",
 ]