diff --git a/comfy/ops.py b/comfy/ops.py index 7ce45bc88..df752389b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -690,25 +690,17 @@ from .quant_ops import ( class QuantLinearFunc(torch.autograd.Function): - """Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward. + """Custom autograd function for quantized linear: quantized forward, compute_dtype backward. + Handles batched (rank >= 2) input by flattening leading dims to 2D for matmul and restoring shape after. """ @staticmethod def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype): - # Save for backward - ctx.save_for_backward(input_float, weight) - ctx.has_bias = bias is not None - ctx.compute_dtype = compute_dtype - ctx.weight_requires_grad = weight.requires_grad - - # Detach: QuantizedTensor.from_float and the patched F.linear - # do not support tensors with requires_grad - inp = input_float.detach() - if inp.ndim >= 3: - inp = inp.reshape(-1, inp.shape[-1]) + input_shape = input_float.shape + inp = input_float.detach().flatten(0, -2) # collapse leading dims to 2D (view when contiguous; requires rank >= 2) # Quantize input (same as inference path) - if layout_type is not None and inp.ndim == 2: + if layout_type is not None: q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) else: q_input = inp @@ -716,13 +708,26 @@ class QuantLinearFunc(torch.autograd.Function): w = weight.detach() if weight.requires_grad else weight b = bias.detach() if bias is not None and bias.requires_grad else bias - return torch.nn.functional.linear(q_input, w, b) + output = torch.nn.functional.linear(q_input, w, b) + + # Restore original input shape + if len(input_shape) > 2: + output = output.unflatten(0, input_shape[:-1]) + + ctx.save_for_backward(input_float, weight) + ctx.input_shape = input_shape + ctx.has_bias = bias is not None + ctx.compute_dtype = compute_dtype + ctx.weight_requires_grad = weight.requires_grad + + return output @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, 
grad_output): input_float, weight = ctx.saved_tensors compute_dtype = ctx.compute_dtype + grad_2d = grad_output.flatten(0, -2).to(compute_dtype) # Dequantize weight to compute dtype for backward matmul if isinstance(weight, QuantizedTensor): @@ -730,28 +735,21 @@ class QuantLinearFunc(torch.autograd.Function): else: weight_f = weight.to(compute_dtype) - # Cast grad_output to compute dtype (handles non-standard dtypes like fp8) - grad_output_f = grad_output.to(compute_dtype) - # grad_input = grad_output @ weight - grad_input = grad_output_f.matmul(weight_f) - - # Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward) - if grad_input.shape != input_float.shape: - grad_input = grad_input.reshape(input_float.shape) + grad_input = torch.mm(grad_2d, weight_f) + if len(ctx.input_shape) > 2: + grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) # grad_weight (only if weight requires grad, typically frozen for quantized training) grad_weight = None if ctx.weight_requires_grad: - input_f = input_float.to(compute_dtype) - if input_f.ndim >= 3: - input_f = input_f.reshape(-1, input_f.shape[-1]) - grad_weight = grad_output_f.t().matmul(input_f) + input_f = input_float.flatten(0, -2).to(compute_dtype) + grad_weight = torch.mm(grad_2d.t(), input_f) # grad_bias grad_bias = None if ctx.has_bias: - grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1))) + grad_bias = grad_2d.sum(dim=0) return grad_input, grad_weight, grad_bias, None, None, None @@ -941,10 +939,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec len(self.weight_function) == 0 and len(self.bias_function) == 0 ) - # Training path: FP8 forward with compute_dtype backward via autograd function - # Only for FP8 layouts (not NVFP4 which packs 2 elements per byte) - if (input.requires_grad and _use_quantized and - getattr(self, 'layout_type', '').startswith('TensorCoreFP8')): + # Training path: quantized forward with compute_dtype backward 
via autograd function + if (input.requires_grad and _use_quantized): weight, bias, offload_stream = cast_bias_weight( self, @@ -962,9 +958,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input, weight, bias, self.layout_type, scale, compute_dtype ) - if input.ndim == 3: - output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0]) - uncast_bias_weight(self, weight, bias, offload_stream) return output