Fix FP8 activation quantization for >2D activations in mixed_precision_ops

mixed_precision_ops.Linear.forward only quantized activations that were 2D or
3D (the latter reshaped to 2D). Inputs with rank >= 4 (e.g. Anima's MLP
activations, which are not reshaped to 3D the way the attention path is) failed
the `input_reshaped.ndim == 2` guard and reached scaled_mm as bf16, silently
dispatching a bf16 kernel instead of FP8. Since the MLP is roughly half the
compute, the FP8 speedup was far below expectation.
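
The effect of the old guard can be sketched with shape arithmetic alone. This is
an illustration under assumed shapes, not the code in mixed_precision_ops; the
helper name `old_quantizes` is hypothetical.

```python
def old_quantizes(shape):
    """Return True if the pre-fix inference path would quantize this input shape."""
    # Old rule: only 3D inputs were flattened to 2D.
    if len(shape) == 3:
        reshaped = (shape[0] * shape[1], shape[2])
    else:
        reshaped = tuple(shape)
    # Old guard: quantize only when the (possibly reshaped) input is 2D.
    return len(reshaped) == 2


# Illustrative shapes (assumptions, not taken from any model).
for shape in [(8, 64), (2, 8, 64), (2, 4, 8, 64)]:
    status = "quantized" if old_quantizes(shape) else "skipped (stays bf16)"
    print(f"rank {len(shape)}: {status}")
```

Only the rank-4 shape reports as skipped, which is the case this commit changes.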

Generalize the existing 3D->2D reshape to any rank >= 3 (flatten the leading
dims, keep the contraction dim) and reshape the output back to the original
leading dims. 2D and 3D inputs are handled exactly as before; only rank >= 4
inputs change (now quantized instead of skipped). This matches the rank-agnostic
handling already used by the training path (flatten(0, -2) / unflatten).
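
A minimal shape-only sketch of the generalized round trip, assuming an
illustrative rank-4 input. The helper names `flatten_leading` and
`restore_leading` are hypothetical, and `out_features` stands in for
`self.weight.shape[0]`.

```python
import math


def flatten_leading(shape):
    """New rule: rank >= 3 collapses all leading dims; lower ranks are untouched."""
    if len(shape) >= 3:
        return (math.prod(shape[:-1]), shape[-1])
    return tuple(shape)


def restore_leading(input_shape, out_features):
    """Shape after reshaping the 2D output back: original leading dims + out_features."""
    return (*input_shape[:-1], out_features)


input_shape = (2, 4, 8, 32)  # illustrative rank-4 activation (assumption)
out_features = 128           # stands in for self.weight.shape[0]

print(flatten_leading(input_shape))                # (64, 32)
print(restore_leading(input_shape, out_features))  # (2, 4, 8, 128)
```

For a 3D input the two helpers reduce to the previous behavior: `(B, S, K)` flattens to `(B * S, K)` and the output is restored to `(B, S, out_features)`.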

Fixes #14595.

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
liminfei-amd 2026-06-26 05:26:56 +00:00
parent 7cb784e0f4
commit f8af8dadaf

@ -1235,7 +1235,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
run_every_op()
input_shape = input.shape
-reshaped_3d = False
+reshaped_nd = False
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
@ -1272,12 +1272,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Inference path (unchanged)
if _use_quantized and quantize_input:
-# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
-input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
+# Reshape >=3D tensors to 2D for quantization (needed for NVFP4 and others)
+input_reshaped = input.reshape(-1, input_shape[-1]) if input.ndim >= 3 else input
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
-reshaped_3d = input.ndim == 3
+reshaped_nd = input.ndim >= 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
@ -1292,9 +1292,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight_only_quant=weight_only_quant,
)
-# Reshape output back to 3D if input was 3D
-if reshaped_3d:
-output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
+# Reshape output back to original rank if input was >2D
+if reshaped_nd:
+output = output.reshape((*input_shape[:-1], self.weight.shape[0]))
return output