Full fix on bad shape handling

We also ensured  comments are matching the logic
This commit is contained in:
Kohaku-Blueleaf 2026-03-02 20:05:50 +08:00
parent f14adb8282
commit 4cebbc50f7

View File

@ -690,25 +690,17 @@ from .quant_ops import (
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward.
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
# Save for backward
ctx.save_for_backward(input_float, weight)
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
# Detach: QuantizedTensor.from_float and the patched F.linear
# do not support tensors with requires_grad
inp = input_float.detach()
if inp.ndim >= 3:
inp = inp.reshape(-1, inp.shape[-1])
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
if layout_type is not None and inp.ndim == 2:
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
@ -716,13 +708,26 @@ class QuantLinearFunc(torch.autograd.Function):
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
return torch.nn.functional.linear(q_input, w, b)
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
@ -730,28 +735,21 @@ class QuantLinearFunc(torch.autograd.Function):
else:
weight_f = weight.to(compute_dtype)
# Cast grad_output to compute dtype (handles non-standard dtypes like fp8)
grad_output_f = grad_output.to(compute_dtype)
# grad_input = grad_output @ weight
grad_input = grad_output_f.matmul(weight_f)
# Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward)
if grad_input.shape != input_float.shape:
grad_input = grad_input.reshape(input_float.shape)
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.to(compute_dtype)
if input_f.ndim >= 3:
input_f = input_f.reshape(-1, input_f.shape[-1])
grad_weight = grad_output_f.t().matmul(input_f)
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1)))
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
@ -941,10 +939,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: FP8 forward with compute_dtype backward via autograd function
# Only for FP8 layouts (not NVFP4 which packs 2 elements per byte)
if (input.requires_grad and _use_quantized and
getattr(self, 'layout_type', '').startswith('TensorCoreFP8')):
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
@ -962,9 +958,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
input, weight, bias, self.layout_type, scale, compute_dtype
)
if input.ndim == 3:
output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0])
uncast_bias_weight(self, weight, bias, offload_stream)
return output