From f8af8dadaf981229166595e57089e27a94e66c9c Mon Sep 17 00:00:00 2001 From: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com> Date: Fri, 26 Jun 2026 05:26:56 +0000 Subject: [PATCH] Fix FP8 activation quantization for >2D activations in mixed_precision_ops mixed_precision_ops.Linear.forward only quantized activations that were 2D, or 3D (reshaped to 2D). Inputs with rank >= 4 (e.g. Anima's MLP activations, which are not reshaped to 3D the way the attention path is) fell through the `input_reshaped.ndim == 2` guard and reached scaled_mm as bf16, silently dispatching a bf16 kernel instead of FP8. Since MLP is roughly half the compute, the FP8 speedup was far below expectation. Generalize the existing 3D->2D reshape to any rank >= 3 (flatten the leading dims, keep the contraction dim) and reshape the output back to the original leading dims. 2D and 3D inputs are handled exactly as before; only rank >= 4 inputs change (now quantized instead of skipped). This matches the rank-agnostic handling already used by the training path (flatten(0, -2) / unflatten). Fixes #14595. Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com> --- comfy/ops.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 634610f1c..b2326b13c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1235,7 +1235,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec run_every_op() input_shape = input.shape - reshaped_3d = False + reshaped_nd = False #If cast needs to apply lora, it should be done in the compute dtype compute_dtype = input.dtype @@ -1272,12 +1272,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Inference path (unchanged) if _use_quantized and quantize_input: - # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) - input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input + # Reshape >=3D tensors to 2D for quantization (needed for NVFP4 and others) + input_reshaped = input.reshape(-1, input_shape[-1]) if input.ndim >= 3 else input # Fall back to non-quantized for non-2D tensors if input_reshaped.ndim == 2: - reshaped_3d = input.ndim == 3 + reshaped_nd = input.ndim >= 3 # dtype is now implicit in the layout class scale = getattr(self, 'input_scale', None) if scale is not None: @@ -1292,9 +1292,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec weight_only_quant=weight_only_quant, ) - # Reshape output back to 3D if input was 3D - if reshaped_3d: - output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) + # Reshape output back to original rank if input was >2D + if reshaped_nd: + output = output.reshape((*input_shape[:-1], self.weight.shape[0])) return output