FP8 bwd training (#13121)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Build package / Build Test (3.11) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Build package / Build Test (3.12) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Build package / Build Test (3.13) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Waiting to run
Build package / Build Test (3.14) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
Kohaku-Blueleaf 2026-03-25 08:39:04 +08:00 committed by GitHub
parent a0a64c679f
commit 5ebb0c2e0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 16 deletions

View File

@ -55,6 +55,7 @@ total_vram = 0
# Training Related State
in_training = False
training_fp8_bwd = False
def get_supported_float8_types():

View File

@ -777,8 +777,16 @@ from .quant_ops import (
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
When training_fp8_bwd is enabled:
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
- Cached input is FP8 (half the memory of bf16)
When training_fp8_bwd is disabled:
- Forward: quantize input per layout, use quantized matmul
- Backward: dequantize weight to compute_dtype, use standard matmul
"""
@staticmethod
@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
# Quantize input for forward (same layout as weight)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
# Unflatten output to match original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
# Save for backward
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
if ctx.fp8_bwd:
# Cache FP8 quantized input — half the memory of bf16
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
ctx.q_input = q_input # already FP8, reuse
else:
# NVFP4 or other layout — quantize input to FP8 for backward
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
ctx.save_for_backward(weight)
else:
ctx.q_input = None
ctx.save_for_backward(input_float, weight)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
# Value casting — only difference between fp8 and non-fp8 paths
if ctx.fp8_bwd:
weight, = ctx.saved_tensors
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
weight_mm = weight
elif isinstance(weight, QuantizedTensor):
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
else:
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
input_mm = ctx.q_input
else:
weight_f = weight.to(compute_dtype)
input_float, weight = ctx.saved_tensors
# Standard tensors → torch.mm does regular matmul
grad_mm = grad_2d
if isinstance(weight, QuantizedTensor):
weight_mm = weight.dequantize().to(compute_dtype)
else:
weight_mm = weight.to(compute_dtype)
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
# Computation — same for both paths, dispatch handles the rest
grad_input = torch.mm(grad_mm, weight_mm)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
grad_weight = torch.mm(grad_mm.t(), input_mm)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)

View File

@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
default="bf16",
tooltip="The dtype to use for lora.",
),
io.Boolean.Input(
"quantized_backward",
default=False,
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
),
io.Combo.Input(
"algorithm",
options=list(adapter_maps.keys()),
@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
seed,
training_dtype,
lora_dtype,
quantized_backward,
algorithm,
gradient_checkpointing,
checkpoint_depth,
@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
quantized_backward = quantized_backward[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
comfy.model_management.training_fp8_bwd = quantized_backward
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)