diff --git a/comfy/float.py b/comfy/float.py index 88c47cd80..184b3d6d0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed= output_block[i:i + slice_size].copy_(block) return output_fp4, to_blocked(output_block, flatten=False) + + +def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0): + def roundup(x_val, multiple): + return ((x_val + multiple - 1) // multiple) * multiple + + if pad_32x: + rows, cols = x.shape + padded_rows = roundup(rows, 32) + padded_cols = roundup(cols, 32) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + F8_E4M3_MAX = 448.0 + E8M0_BIAS = 127 + BLOCK_SIZE = 32 + + rows, cols = x.shape + x_blocked = x.reshape(rows, -1, BLOCK_SIZE) + max_abs = torch.amax(torch.abs(x_blocked), dim=-1) + + # E8M0 block scales (power-of-2 exponents) + scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127)) + exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254) + block_scales_e8m0 = exp_biased.to(torch.uint8) + + zero_mask = (max_abs == 0) + block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32) + block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32) + + # Scale per-block then stochastic round + data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols) + output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed) + + block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0) + return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index d3b4c5b65..da730b5a2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -8,7 +8,7 @@ try: QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout, - TensorCoreMXFP8Layout, + TensorCoreMXFP8Layout as _CKMxfp8Layout, register_layout_op, register_layout_class, get_layout_class, @@ -38,7 +38,7 @@ except ImportError as e: class _CKNvfp4Layout: pass - class TensorCoreMXFP8Layout: + class _CKMxfp8Layout: pass def register_layout_class(name, cls): @@ -88,6 +88,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreMXFP8Layout(_CKMxfp8Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding) + + params = cls.Params( + scale=block_scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + ) + return qdata, params + + class TensorCoreNVFP4Layout(_CKNvfp4Layout): @classmethod def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):