diff --git a/comfy/ops.py b/comfy/ops.py
index 6ee6075fb..7ce45bc88 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -689,6 +689,73 @@ from .quant_ops import (
 )
 
 
+class QuantLinearFunc(torch.autograd.Function):
+    """Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward.
+    """
+
+    @staticmethod
+    def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
+        # Save for backward
+        ctx.save_for_backward(input_float, weight)
+        ctx.has_bias = bias is not None
+        ctx.compute_dtype = compute_dtype
+        ctx.weight_requires_grad = weight.requires_grad
+
+        # Detach: QuantizedTensor.from_float and the patched F.linear
+        # do not support tensors with requires_grad
+        inp = input_float.detach()
+        if inp.ndim >= 3:
+            inp = inp.reshape(-1, inp.shape[-1])
+
+        # Quantize input (same as inference path)
+        if layout_type is not None and inp.ndim == 2:
+            q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
+        else:
+            q_input = inp
+
+        w = weight.detach() if weight.requires_grad else weight
+        b = bias.detach() if bias is not None and bias.requires_grad else bias
+
+        return torch.nn.functional.linear(q_input, w, b)
+
+    @staticmethod
+    @torch.autograd.function.once_differentiable
+    def backward(ctx, grad_output):
+        input_float, weight = ctx.saved_tensors
+        compute_dtype = ctx.compute_dtype
+
+        # Dequantize weight to compute dtype for backward matmul
+        if isinstance(weight, QuantizedTensor):
+            weight_f = weight.dequantize().to(compute_dtype)
+        else:
+            weight_f = weight.to(compute_dtype)
+
+        # Cast grad_output to compute dtype (handles non-standard dtypes like fp8)
+        grad_output_f = grad_output.to(compute_dtype)
+
+        # grad_input = grad_output @ weight
+        grad_input = grad_output_f.matmul(weight_f)
+
+        # Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward)
+        if grad_input.shape != input_float.shape:
+            grad_input = grad_input.reshape(input_float.shape)
+
+        # grad_weight (only if weight requires grad, typically frozen for quantized training)
+        grad_weight = None
+        if ctx.weight_requires_grad:
+            input_f = input_float.to(compute_dtype)
+            if input_f.ndim >= 3:
+                input_f = input_f.reshape(-1, input_f.shape[-1])
+            grad_weight = grad_output_f.t().matmul(input_f)
+
+        # grad_bias
+        grad_bias = None
+        if ctx.has_bias:
+            grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1)))
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
 def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
     class MixedPrecisionOps(manual_cast):
         _quant_config = quant_config
@@ -867,10 +934,42 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
                 #If cast needs to apply lora, it should be done in the compute dtype
                 compute_dtype = input.dtype
-                if (getattr(self, 'layout_type', None) is not None and
+                _use_quantized = (
+                    getattr(self, 'layout_type', None) is not None and
                     not isinstance(input, QuantizedTensor) and
                     not self._full_precision_mm and
                     not getattr(self, 'comfy_force_cast_weights', False) and
-                    len(self.weight_function) == 0 and len(self.bias_function) == 0):
+                    len(self.weight_function) == 0 and len(self.bias_function) == 0
+                )
+
+                # Training path: FP8 forward with compute_dtype backward via autograd function
+                # Only for FP8 layouts (not NVFP4 which packs 2 elements per byte)
+                if (input.requires_grad and _use_quantized and
+                    getattr(self, 'layout_type', '').startswith('TensorCoreFP8')):
+
+                    weight, bias, offload_stream = cast_bias_weight(
+                        self,
+                        input,
+                        offloadable=True,
+                        compute_dtype=compute_dtype,
+                        want_requant=True
+                    )
+
+                    scale = getattr(self, 'input_scale', None)
+                    if scale is not None:
+                        scale = comfy.model_management.cast_to_device(scale, input.device, None)
+
+                    output = QuantLinearFunc.apply(
+                        input, weight, bias, self.layout_type, scale, compute_dtype
+                    )
+
+                    if input.ndim == 3:
+                        output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0])
+
+                    uncast_bias_weight(self, weight, bias, offload_stream)
+                    return output
+
+                # Inference path (unchanged)
+                if _use_quantized:
                     # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
                     input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@@ -918,7 +1017,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 for key, param in self._parameters.items():
                     if param is None:
                         continue
-                    self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+                    p = fn(param)
+                    if p.is_inference():
+                        p = p.clone()
+                    self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
                 for key, buf in self._buffers.items():
                     if buf is not None:
                         self._buffers[key] = fn(buf)
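
For reviewers: the gradient flow in QuantLinearFunc is a straight-through arrangement; the forward matmul runs on quantized operands, while backward differentiates as if the linear had run in compute_dtype on the dequantized weight. Below is a minimal self-contained sketch of that pattern, using a torch.float8_e4m3fn round-trip as a stand-in for comfy's QuantizedTensor / patched F.linear path. STELinear and everything in this snippet are hypothetical illustration, not part of the patch:

import torch

class STELinear(torch.autograd.Function):
    """Toy analogue of QuantLinearFunc: quantized forward, full-precision backward."""

    @staticmethod
    def forward(ctx, x, w, b):
        ctx.save_for_backward(x, w)
        ctx.has_bias = b is not None
        ctx.weight_requires_grad = w.requires_grad
        # Stand-in for QuantizedTensor.from_float + the patched F.linear:
        # round-trip through fp8 so only the forward sees quantization error
        # (torch.float8_e4m3fn needs PyTorch >= 2.1)
        x_q = x.detach().float().to(torch.float8_e4m3fn).to(x.dtype)
        w_q = w.detach().float().to(torch.float8_e4m3fn).to(w.dtype)
        return torch.nn.functional.linear(x_q, w_q, b)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_out.matmul(w)  # same formula the patch uses on the dequantized weight
        grad_w = grad_out.t().matmul(x) if ctx.weight_requires_grad else None
        grad_b = grad_out.sum(dim=list(range(grad_out.ndim - 1))) if ctx.has_bias else None
        return grad_x, grad_w, grad_b

x = torch.randn(4, 16, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(8, 16, dtype=torch.bfloat16)
b = torch.zeros(8, dtype=torch.bfloat16)
STELinear.apply(x, w, b).sum().backward()

# Since backward never sees the quantization, the input gradient is the same
# bf16 matmul an unquantized linear would produce on identical operands:
x_ref = x.detach().clone().requires_grad_(True)
torch.nn.functional.linear(x_ref, w, b).sum().backward()
assert torch.equal(x.grad, x_ref.grad)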
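
The _apply hunk at the end addresses a separate issue: a tensor created under torch.inference_mode() cannot be wrapped in torch.nn.Parameter (PyTorch raises a RuntimeError, since inference tensors carry no autograd metadata), and fn(param) can presumably hand back such a tensor when the weights were materialized under inference mode. A standalone illustration of the failure mode and of the clone() escape hatch the patch applies:

import torch

with torch.inference_mode():
    w = torch.randn(8, 8)  # an inference tensor

assert w.is_inference()
# torch.nn.Parameter(w) would raise a RuntimeError at this point;
# clone() outside inference mode returns an ordinary tensor, which is
# what the patched _apply does before re-registering the parameter.
p = torch.nn.Parameter(w.clone(), requires_grad=False)
assert not p.is_inference()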