mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-21 09:03:37 +08:00
Support quant linear fwdbwd
This commit is contained in:
parent
2e94badbe0
commit
eb33188c8e
108
comfy/ops.py
108
comfy/ops.py
@ -689,6 +689,73 @@ from .quant_ops import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
|
"""Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||||
|
# Save for backward
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
ctx.has_bias = bias is not None
|
||||||
|
ctx.compute_dtype = compute_dtype
|
||||||
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
|
||||||
|
# Detach: QuantizedTensor.from_float and the patched F.linear
|
||||||
|
# do not support tensors with requires_grad
|
||||||
|
inp = input_float.detach()
|
||||||
|
if inp.ndim >= 3:
|
||||||
|
inp = inp.reshape(-1, inp.shape[-1])
|
||||||
|
|
||||||
|
# Quantize input (same as inference path)
|
||||||
|
if layout_type is not None and inp.ndim == 2:
|
||||||
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
|
else:
|
||||||
|
q_input = inp
|
||||||
|
|
||||||
|
w = weight.detach() if weight.requires_grad else weight
|
||||||
|
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||||
|
|
||||||
|
return torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_float, weight = ctx.saved_tensors
|
||||||
|
compute_dtype = ctx.compute_dtype
|
||||||
|
|
||||||
|
# Dequantize weight to compute dtype for backward matmul
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_f = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_f = weight.to(compute_dtype)
|
||||||
|
|
||||||
|
# Cast grad_output to compute dtype (handles non-standard dtypes like fp8)
|
||||||
|
grad_output_f = grad_output.to(compute_dtype)
|
||||||
|
|
||||||
|
# grad_input = grad_output @ weight
|
||||||
|
grad_input = grad_output_f.matmul(weight_f)
|
||||||
|
|
||||||
|
# Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward)
|
||||||
|
if grad_input.shape != input_float.shape:
|
||||||
|
grad_input = grad_input.reshape(input_float.shape)
|
||||||
|
|
||||||
|
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||||
|
grad_weight = None
|
||||||
|
if ctx.weight_requires_grad:
|
||||||
|
input_f = input_float.to(compute_dtype)
|
||||||
|
if input_f.ndim >= 3:
|
||||||
|
input_f = input_f.reshape(-1, input_f.shape[-1])
|
||||||
|
grad_weight = grad_output_f.t().matmul(input_f)
|
||||||
|
|
||||||
|
# grad_bias
|
||||||
|
grad_bias = None
|
||||||
|
if ctx.has_bias:
|
||||||
|
grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1)))
|
||||||
|
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
_quant_config = quant_config
|
_quant_config = quant_config
|
||||||
@ -867,10 +934,42 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
#If cast needs to apply lora, it should be done in the compute dtype
|
#If cast needs to apply lora, it should be done in the compute dtype
|
||||||
compute_dtype = input.dtype
|
compute_dtype = input.dtype
|
||||||
|
|
||||||
if (getattr(self, 'layout_type', None) is not None and
|
_use_quantized = (
|
||||||
|
getattr(self, 'layout_type', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training path: FP8 forward with compute_dtype backward via autograd function
|
||||||
|
# Only for FP8 layouts (not NVFP4 which packs 2 elements per byte)
|
||||||
|
if (input.requires_grad and _use_quantized and
|
||||||
|
getattr(self, 'layout_type', '').startswith('TensorCoreFP8')):
|
||||||
|
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(
|
||||||
|
self,
|
||||||
|
input,
|
||||||
|
offloadable=True,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
want_requant=True
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = getattr(self, 'input_scale', None)
|
||||||
|
if scale is not None:
|
||||||
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
|
|
||||||
|
output = QuantLinearFunc.apply(
|
||||||
|
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if input.ndim == 3:
|
||||||
|
output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0])
|
||||||
|
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return output
|
||||||
|
|
||||||
|
# Inference path (unchanged)
|
||||||
|
if _use_quantized:
|
||||||
|
|
||||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||||
@ -918,7 +1017,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
for key, param in self._parameters.items():
|
for key, param in self._parameters.items():
|
||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
p = fn(param)
|
||||||
|
if p.is_inference():
|
||||||
|
p = p.clone()
|
||||||
|
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
for key, buf in self._buffers.items():
|
for key, buf in self._buffers.items():
|
||||||
if buf is not None:
|
if buf is not None:
|
||||||
self._buffers[key] = fn(buf)
|
self._buffers[key] = fn(buf)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user