From 5ebb0c2e0b72945c271a2fb4db749585aa32a13c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Mar 2026 08:39:04 +0800 Subject: [PATCH] FP8 bwd training (#13121) --- comfy/model_management.py | 1 + comfy/ops.py | 65 ++++++++++++++++++++++++++++--------- comfy_extras/nodes_train.py | 9 +++++ 3 files changed, 59 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2c250dacc..9617d8388 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -55,6 +55,7 @@ total_vram = 0 # Training Related State in_training = False +training_fp8_bwd = False def get_supported_float8_types(): diff --git a/comfy/ops.py b/comfy/ops.py index 1518ec9de..ca25693db 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -777,8 +777,16 @@ from .quant_ops import ( class QuantLinearFunc(torch.autograd.Function): - """Custom autograd function for quantized linear: quantized forward, compute_dtype backward. - Handles any input rank by flattening to 2D for matmul and restoring shape after. + """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward. + + When training_fp8_bwd is enabled: + - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul + - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch + - Cached input is FP8 (half the memory of bf16) + + When training_fp8_bwd is disabled: + - Forward: quantize input per layout, use quantized matmul + - Backward: dequantize weight to compute_dtype, use standard matmul """ @staticmethod @@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function): input_shape = input_float.shape inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D - # Quantize input (same as inference path) + # Quantize input for forward (same layout as weight) if layout_type is not None: q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) else: @@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function): output = torch.nn.functional.linear(q_input, w, b) - # Restore original input shape + # Unflatten output to match original input shape if len(input_shape) > 2: output = output.unflatten(0, input_shape[:-1]) - ctx.save_for_backward(input_float, weight) + # Save for backward ctx.input_shape = input_shape ctx.has_bias = bias is not None ctx.compute_dtype = compute_dtype ctx.weight_requires_grad = weight.requires_grad + ctx.fp8_bwd = comfy.model_management.training_fp8_bwd + + if ctx.fp8_bwd: + # Cache FP8 quantized input — half the memory of bf16 + if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'): + ctx.q_input = q_input # already FP8, reuse + else: + # NVFP4 or other layout — quantize input to FP8 for backward + ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout") + ctx.save_for_backward(weight) + else: + ctx.q_input = None + ctx.save_for_backward(input_float, weight) return output @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, grad_output): - input_float, weight = ctx.saved_tensors compute_dtype = ctx.compute_dtype grad_2d = grad_output.flatten(0, -2).to(compute_dtype) - # Dequantize weight to compute dtype for backward matmul - if isinstance(weight, QuantizedTensor): - weight_f = weight.dequantize().to(compute_dtype) + # Value casting — only difference between fp8 and non-fp8 paths + if ctx.fp8_bwd: + weight, = ctx.saved_tensors + # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm + grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout") + if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"): + weight_mm = weight + elif isinstance(weight, QuantizedTensor): + weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout") + else: + weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout") + input_mm = ctx.q_input else: - weight_f = weight.to(compute_dtype) + input_float, weight = ctx.saved_tensors + # Standard tensors → torch.mm does regular matmul + grad_mm = grad_2d + if isinstance(weight, QuantizedTensor): + weight_mm = weight.dequantize().to(compute_dtype) + else: + weight_mm = weight.to(compute_dtype) + input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None - # grad_input = grad_output @ weight - grad_input = torch.mm(grad_2d, weight_f) + # Computation — same for both paths, dispatch handles the rest + grad_input = torch.mm(grad_mm, weight_mm) if len(ctx.input_shape) > 2: grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) - # grad_weight (only if weight requires grad, typically frozen for quantized training) grad_weight = None if ctx.weight_requires_grad: - input_f = input_float.flatten(0, -2).to(compute_dtype) - grad_weight = torch.mm(grad_2d.t(), input_f) + grad_weight = torch.mm(grad_mm.t(), input_mm) - # grad_bias grad_bias = None if ctx.has_bias: grad_bias = grad_2d.sum(dim=0) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0ad0acee6..df1b39fd5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode): default="bf16", tooltip="The dtype to use for lora.", ), + io.Boolean.Input( + "quantized_backward", + default=False, + tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.", + ), io.Combo.Input( "algorithm", options=list(adapter_maps.keys()), @@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode): seed, training_dtype, lora_dtype, + quantized_backward, algorithm, gradient_checkpointing, checkpoint_depth, @@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode): seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] + quantized_backward = quantized_backward[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] offloading = offloading[0] @@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode): bucket_mode = bucket_mode[0] bypass_mode = bypass_mode[0] + comfy.model_management.training_fp8_bwd = quantized_backward + # Process latents based on mode if bucket_mode: latents = _process_latents_bucket_mode(latents)