diff --git a/comfy/float.py b/comfy/float.py index 88c47cd80..184b3d6d0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed= output_block[i:i + slice_size].copy_(block) return output_fp4, to_blocked(output_block, flatten=False) + + +def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0): + def roundup(x_val, multiple): + return ((x_val + multiple - 1) // multiple) * multiple + + if pad_32x: + rows, cols = x.shape + padded_rows = roundup(rows, 32) + padded_cols = roundup(cols, 32) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + F8_E4M3_MAX = 448.0 + E8M0_BIAS = 127 + BLOCK_SIZE = 32 + + rows, cols = x.shape + x_blocked = x.reshape(rows, -1, BLOCK_SIZE) + max_abs = torch.amax(torch.abs(x_blocked), dim=-1) + + # E8M0 block scales (power-of-2 exponents) + scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127)) + exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254) + block_scales_e8m0 = exp_biased.to(torch.uint8) + + zero_mask = (max_abs == 0) + block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32) + block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32) + + # Scale per-block then stochastic round + data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols) + output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed) + + block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0) + return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu) diff --git a/comfy/model_management.py b/comfy/model_management.py index 4d5851bc0..bb77cff47 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1712,6 +1712,19 @@ def supports_nvfp4_compute(device=None): return True +def supports_mxfp8_compute(device=None): + if not is_nvidia(): + return False + + if torch_version_numeric < (2, 10): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index 3f2da4e63..59c0df87d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec orig_shape=(self.out_features, self.in_features), ) + elif self.quant_format == "mxfp8": + # MXFP8: E8M0 block scales stored as uint8 in safetensors + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.uint8) + + if block_scale is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + + block_scale = block_scale.view(torch.float8_e8m0fnu) + + params = layout_cls.Params( + scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + elif self.quant_format == "nvfp4": # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) @@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) + mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") disabled = set() if not nvfp4_compute: disabled.add("nvfp4") + if not mxfp8_compute: + disabled.add("mxfp8") if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457b..42ee08fb2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -43,6 +43,18 @@ except ImportError as e: def get_layout_class(name): return None +_CK_MXFP8_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout + _CK_MXFP8_AVAILABLE = True + except ImportError: + logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.") + +if not _CK_MXFP8_AVAILABLE: + class _CKMxfp8Layout: + pass + import comfy.float # ============================================================================== @@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreMXFP8Layout(_CKMxfp8Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding) + + params = cls.Params( + scale=block_scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + ) + return qdata, params + + class TensorCoreNVFP4Layout(_CKNvfp4Layout): @classmethod def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): @@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +if _CK_MXFP8_AVAILABLE: + register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -157,6 +196,14 @@ QUANT_ALGOS = { }, } +if _CK_MXFP8_AVAILABLE: + QUANT_ALGOS["mxfp8"] = { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreMXFP8Layout", + "group_size": 32, + } + # ============================================================================== # Re-exports for backward compatibility