Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-03-28 20:43:32 +08:00)
FP8 bwd training (#13121)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Build package / Build Test (3.11) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Build package / Build Test (3.12) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Build package / Build Test (3.13) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Waiting to run
Build package / Build Test (3.14) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
parent a0a64c679f
commit 5ebb0c2e0b
comfy/model_management.py

@@ -55,6 +55,7 @@ total_vram = 0

 # Training Related State
 in_training = False
+training_fp8_bwd = False


 def get_supported_float8_types():
comfy/ops.py (65 changed lines)
@@ -777,8 +777,16 @@ from .quant_ops import (


 class QuantLinearFunc(torch.autograd.Function):
-    """Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
-    Handles any input rank by flattening to 2D for matmul and restoring shape after.
+    """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
+
+    When training_fp8_bwd is enabled:
+    - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
+    - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
+    - Cached input is FP8 (half the memory of bf16)
+
+    When training_fp8_bwd is disabled:
+    - Forward: quantize input per layout, use quantized matmul
+    - Backward: dequantize weight to compute_dtype, use standard matmul
     """

     @staticmethod
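The memory claim in the new docstring ("Cached input is FP8, half the memory of bf16") follows directly from element sizes. A minimal sketch, assuming a PyTorch build that exposes the float8 dtypes (2.1 or newer); it compares storage only and ignores the scaling metadata a real QuantizedTensor carries:

import torch

# bf16 stores 2 bytes per element, float8_e4m3fn stores 1 byte per element,
# so caching the FP8 copy of the activation between forward and backward
# halves that part of the training memory footprint.
x_bf16 = torch.randn(1024, 1024, dtype=torch.bfloat16)
x_fp8 = x_bf16.to(torch.float8_e4m3fn)  # lossy cast; per-tensor scaling omitted here

print(x_bf16.element_size(), x_fp8.element_size())    # 2 1
print(x_bf16.numel() * x_bf16.element_size())         # 2097152 bytes
print(x_fp8.numel() * x_fp8.element_size())           # 1048576 bytes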
@@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
         input_shape = input_float.shape
         inp = input_float.detach().flatten(0, -2)  # zero-cost view to 2D

-        # Quantize input (same as inference path)
+        # Quantize input for forward (same layout as weight)
         if layout_type is not None:
             q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
         else:
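For context on what QuantizedTensor.from_float(inp, layout_type, scale=input_scale) does conceptually in this forward path, here is a hedged sketch of plain per-tensor FP8 E4M3 quantization with an amax-derived scale. It illustrates the general technique only; ComfyUI's layout classes handle the actual scale and metadata bookkeeping:

import torch

def quantize_fp8_e4m3(x: torch.Tensor, scale: torch.Tensor | None = None):
    """Per-tensor FP8 quantization sketch: pick a scale so the largest value
    lands near the E4M3 limit (448), cast, and keep the scale for dequant."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    if scale is None:
        amax = x.abs().amax().clamp(min=1e-12)
        scale = fp8_max / amax
    q = (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    """Inverse of the sketch above: upcast and divide out the scale."""
    return (q.to(torch.float32) / scale).to(dtype)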
@@ -797,43 +805,68 @@

         output = torch.nn.functional.linear(q_input, w, b)

-        # Restore original input shape
+        # Unflatten output to match original input shape
         if len(input_shape) > 2:
             output = output.unflatten(0, input_shape[:-1])

-        ctx.save_for_backward(input_float, weight)
+        # Save for backward
         ctx.input_shape = input_shape
         ctx.has_bias = bias is not None
         ctx.compute_dtype = compute_dtype
         ctx.weight_requires_grad = weight.requires_grad
+        ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
+
+        if ctx.fp8_bwd:
+            # Cache FP8 quantized input — half the memory of bf16
+            if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
+                ctx.q_input = q_input  # already FP8, reuse
+            else:
+                # NVFP4 or other layout — quantize input to FP8 for backward
+                ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
+            ctx.save_for_backward(weight)
+        else:
+            ctx.q_input = None
+            ctx.save_for_backward(input_float, weight)

         return output

     @staticmethod
     @torch.autograd.function.once_differentiable
     def backward(ctx, grad_output):
-        input_float, weight = ctx.saved_tensors
         compute_dtype = ctx.compute_dtype
         grad_2d = grad_output.flatten(0, -2).to(compute_dtype)

-        # Dequantize weight to compute dtype for backward matmul
-        if isinstance(weight, QuantizedTensor):
-            weight_f = weight.dequantize().to(compute_dtype)
+        # Value casting — only difference between fp8 and non-fp8 paths
+        if ctx.fp8_bwd:
+            weight, = ctx.saved_tensors
+            # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
+            grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
+            if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
+                weight_mm = weight
+            elif isinstance(weight, QuantizedTensor):
+                weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
+            else:
+                weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
+            input_mm = ctx.q_input
         else:
-            weight_f = weight.to(compute_dtype)
+            input_float, weight = ctx.saved_tensors
+            # Standard tensors → torch.mm does regular matmul
+            grad_mm = grad_2d
+            if isinstance(weight, QuantizedTensor):
+                weight_mm = weight.dequantize().to(compute_dtype)
+            else:
+                weight_mm = weight.to(compute_dtype)
+            input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None

-        # grad_input = grad_output @ weight
-        grad_input = torch.mm(grad_2d, weight_f)
+        # Computation — same for both paths, dispatch handles the rest
+        grad_input = torch.mm(grad_mm, weight_mm)
         if len(ctx.input_shape) > 2:
             grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])

-        # grad_weight (only if weight requires grad, typically frozen for quantized training)
         grad_weight = None
         if ctx.weight_requires_grad:
-            input_f = input_float.flatten(0, -2).to(compute_dtype)
-            grad_weight = torch.mm(grad_2d.t(), input_f)
+            grad_weight = torch.mm(grad_mm.t(), input_mm)

-        # grad_bias
         grad_bias = None
         if ctx.has_bias:
             grad_bias = grad_2d.sum(dim=0)
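The forward change above is, at its core, the standard autograd trick of choosing at save time what to stash on ctx. A minimal standalone sketch of that caching pattern, assuming nothing from ComfyUI (no layouts, no bias, no shape handling): keep an FP8 copy of the activation for grad_weight instead of the full-precision original.

import torch

class LowPrecisionCacheLinear(torch.autograd.Function):
    """Sketch only: cache the activation in FP8, save the weight normally."""

    @staticmethod
    def forward(ctx, x, w):
        out = x @ w.t()
        ctx.save_for_backward(w)
        # Stash a 1-byte-per-element copy of the activation for backward.
        ctx.x_fp8 = x.detach().to(torch.float8_e4m3fn)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (w,) = ctx.saved_tensors
        grad_x = grad_out @ w
        # Upcast the cached activation only when grad_weight is actually needed.
        grad_w = grad_out.t() @ ctx.x_fp8.to(grad_out.dtype)
        return grad_x, grad_w

# Usage: x = torch.randn(8, 16, requires_grad=True); w = torch.randn(32, 16, requires_grad=True)
# LowPrecisionCacheLinear.apply(x, w).sum().backward()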
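The comment "torch.mm dispatches to _scaled_mm" refers to the QuantizedTensor subclass intercepting matmul and routing FP8 operands to PyTorch's scaled-matmul kernel. A hedged sketch of what that call looks like when made directly; it assumes an FP8-capable GPU (compute capability 8.9 or newer), a recent PyTorch 2.x, and uses placeholder scales of 1.0 where real code derives them from amax statistics:

import torch

def fp8_mm_sketch(a_hp: torch.Tensor, b_hp: torch.Tensor) -> torch.Tensor:
    """C = A @ B via torch._scaled_mm (a private API whose signature has shifted
    across releases). A is cast to E5M2 as the gradients are here (wider range),
    B to E4M3 as the weights are here (more precision); the second operand must
    be column-major in memory and dimensions typically need to be multiples of 16."""
    a = a_hp.contiguous().to(torch.float8_e5m2)
    b = b_hp.to(torch.float8_e4m3fn).t().contiguous().t()  # column-major copy
    one = torch.ones((), dtype=torch.float32, device=a_hp.device)
    return torch._scaled_mm(a, b, scale_a=one, scale_b=one, out_dtype=torch.bfloat16)

The remaining hunks wire the new quantized_backward toggle through TrainLoraNode: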
@@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
                     default="bf16",
                     tooltip="The dtype to use for lora.",
                 ),
+                io.Boolean.Input(
+                    "quantized_backward",
+                    default=False,
+                    tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
+                ),
                 io.Combo.Input(
                     "algorithm",
                     options=list(adapter_maps.keys()),
@@ -1097,6 +1102,7 @@
         seed,
         training_dtype,
         lora_dtype,
+        quantized_backward,
         algorithm,
         gradient_checkpointing,
         checkpoint_depth,
@@ -1117,6 +1123,7 @@
         seed = seed[0]
         training_dtype = training_dtype[0]
         lora_dtype = lora_dtype[0]
+        quantized_backward = quantized_backward[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
         offloading = offloading[0]
@@ -1125,6 +1132,8 @@
         bucket_mode = bucket_mode[0]
         bypass_mode = bypass_mode[0]

+        comfy.model_management.training_fp8_bwd = quantized_backward
+
         # Process latents based on mode
         if bucket_mode:
             latents = _process_latents_bucket_mode(latents)
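Taken together, the wiring is: TrainLoraNode reads the new quantized_backward input, writes the process-global comfy.model_management.training_fp8_bwd, and every subsequent QuantLinearFunc.forward call consults that flag to decide what to cache and which backward path to take. A minimal sketch of toggling it from calling code; the train_lora_sketch wrapper and its loop are illustrative, not ComfyUI API:

import comfy.model_management

def train_lora_sketch(quantized_backward: bool, steps: int) -> None:
    # Mirrors what TrainLoraNode.execute does with its new input: set the
    # global once before training, then run the usual loop. Each
    # QuantLinearFunc.forward reads the flag and picks the FP8 or
    # compute_dtype backward path accordingly.
    comfy.model_management.training_fp8_bwd = quantized_backward
    for _ in range(steps):
        pass  # forward / loss / backward / optimizer step elided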