From d59731cf264d2e6316e04586a50c28bedaf5ac0d Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:25:42 +0200 Subject: [PATCH 1/8] Support mxfp8 --- comfy/ops.py | 16 ++++++++++++++++ comfy/quant_ops.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 87b36b5c5..81a97a30b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -801,6 +801,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec orig_shape=(self.out_features, self.in_features), ) + elif self.quant_format == "mxfp8": + # MXFP8: E8M0 block scales stored as uint8 in safetensors + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.uint8) + + if block_scale is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + + block_scale = block_scale.view(torch.float8_e8m0fnu) + + params = layout_cls.Params( + scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + elif self.quant_format == "nvfp4": # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457b..43c6fd7ce 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -8,6 +8,7 @@ try: QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout, + TensorCoreMXFP8Layout, register_layout_op, register_layout_class, get_layout_class, @@ -137,6 +138,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -155,6 +157,12 @@ QUANT_ALGOS = { "comfy_tensor_layout": "TensorCoreNVFP4Layout", "group_size": 16, }, + "mxfp8": { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreMXFP8Layout", + "group_size": 32, + }, } @@ -169,6 +177,7 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", + "TensorCoreMXFP8Layout", "QUANT_ALGOS", "register_layout_op", ] From 1f6691077db580d14db67a1314ed286b81feca03 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:11:21 +0200 Subject: [PATCH 2/8] Guards --- comfy/ops.py | 1 + comfy/quant_ops.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 81a97a30b..8c0aaec55 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -975,6 +975,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") + disabled.add("mxfp8") return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 43c6fd7ce..76e755c5f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -38,6 +38,9 @@ except ImportError as e: class _CKNvfp4Layout: pass + class _CKMxfp8Layout: + pass + def register_layout_class(name, cls): pass @@ -138,7 +141,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) -register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) +if _CK_AVAILABLE: + register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { "float8_e4m3fn": { From b322b577ae45df684205f4de6918a464d308684f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:12:34 +0200 Subject: [PATCH 3/8] Update quant_ops.py --- comfy/quant_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 76e755c5f..d67133386 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -181,7 +181,6 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", - "TensorCoreMXFP8Layout", "QUANT_ALGOS", "register_layout_op", ] From 9eceec64d753cb6064b3e6cb6240b9641f879719 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:27:28 +0200 Subject: [PATCH 4/8] Update quant_ops.py --- comfy/quant_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d67133386..d3b4c5b65 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -38,7 +38,7 @@ except ImportError as e: class _CKNvfp4Layout: pass - class _CKMxfp8Layout: + class TensorCoreMXFP8Layout: pass def register_layout_class(name, cls): From 7220240517476633d7ebc587643ade888528e845 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 12 Mar 2026 23:50:30 +0200 Subject: [PATCH 5/8] Add stochastic rounding for LoRAs --- comfy/float.py | 36 ++++++++++++++++++++++++++++++++++++ comfy/quant_ops.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 88c47cd80..184b3d6d0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed= output_block[i:i + slice_size].copy_(block) return output_fp4, to_blocked(output_block, flatten=False) + + +def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0): + def roundup(x_val, multiple): + return ((x_val + multiple - 1) // multiple) * multiple + + if pad_32x: + rows, cols = x.shape + padded_rows = roundup(rows, 32) + padded_cols = roundup(cols, 32) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + F8_E4M3_MAX = 448.0 + E8M0_BIAS = 127 + BLOCK_SIZE = 32 + + rows, cols = x.shape + x_blocked = x.reshape(rows, -1, BLOCK_SIZE) + max_abs = torch.amax(torch.abs(x_blocked), dim=-1) + + # E8M0 block scales (power-of-2 exponents) + scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127)) + exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254) + block_scales_e8m0 = exp_biased.to(torch.uint8) + + zero_mask = (max_abs == 0) + block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32) + block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32) + + # Scale per-block then stochastic round + data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols) + output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed) + + block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0) + return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d3b4c5b65..da730b5a2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -8,7 +8,7 @@ try: QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout, - TensorCoreMXFP8Layout, + TensorCoreMXFP8Layout as _CKMxfp8Layout, register_layout_op, register_layout_class, get_layout_class, @@ -38,7 +38,7 @@ except ImportError as e: class _CKNvfp4Layout: pass - class TensorCoreMXFP8Layout: + class _CKMxfp8Layout: pass def register_layout_class(name, cls): @@ -88,6 +88,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreMXFP8Layout(_CKMxfp8Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding) + + params = cls.Params( + scale=block_scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + ) + return qdata, params + + class TensorCoreNVFP4Layout(_CKNvfp4Layout): @classmethod def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): From 322b416245795ed54f814704391ed6b1e26a00fe Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 13 Mar 2026 00:02:39 +0200 Subject: [PATCH 6/8] Better guards --- comfy/quant_ops.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index da730b5a2..42ee08fb2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -8,7 +8,6 @@ try: QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout, - TensorCoreMXFP8Layout as _CKMxfp8Layout, register_layout_op, register_layout_class, get_layout_class, @@ -38,15 +37,24 @@ except ImportError as e: class _CKNvfp4Layout: pass - class _CKMxfp8Layout: - pass - def register_layout_class(name, cls): pass def get_layout_class(name): return None +_CK_MXFP8_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout + _CK_MXFP8_AVAILABLE = True + except ImportError: + logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.") + +if not _CK_MXFP8_AVAILABLE: + class _CKMxfp8Layout: + pass + import comfy.float # ============================================================================== @@ -166,7 +174,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) -if _CK_AVAILABLE: +if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { @@ -186,13 +194,15 @@ QUANT_ALGOS = { "comfy_tensor_layout": "TensorCoreNVFP4Layout", "group_size": 16, }, - "mxfp8": { +} + +if _CK_MXFP8_AVAILABLE: + QUANT_ALGOS["mxfp8"] = { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, "comfy_tensor_layout": "TensorCoreMXFP8Layout", "group_size": 32, - }, -} + } # ============================================================================== From 95126cea6260b1bd2acaa05f3935e6889bbe688b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 13 Mar 2026 01:58:31 +0200 Subject: [PATCH 7/8] Disable on non-Blackwell --- comfy/model_management.py | 10 ++++++++++ comfy/ops.py | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 81c89b180..e7d327c74 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1662,6 +1662,16 @@ def supports_nvfp4_compute(device=None): return True +def supports_mxfp8_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index 8c0aaec55..6306fd1be 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -966,16 +966,18 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) + mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") disabled = set() if not nvfp4_compute: disabled.add("nvfp4") + if not mxfp8_compute: + disabled.add("mxfp8") if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") - disabled.add("mxfp8") return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( From 26da3c7d0dd2460ffbb5f583cf508da527ba4e64 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 13 Mar 2026 19:22:37 +0200 Subject: [PATCH 8/8] Add guard for torch version (requires 2.10) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e7d327c74..918ad2c10 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1666,6 +1666,9 @@ def supports_mxfp8_compute(device=None): if not is_nvidia(): return False + if torch_version_numeric < (2, 10): + return False + props = torch.cuda.get_device_properties(device) if props.major < 10: return False