This commit is contained in:
Kohaku-Blueleaf 2026-03-15 08:19:20 +01:00 committed by GitHub
commit 54cb87f070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 150 additions and 23 deletions

View File

@ -766,6 +766,71 @@ from .quant_ops import (
)
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
else:
weight_f = weight.to(compute_dtype)
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
@ -960,10 +1025,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and
_use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1011,7 +1103,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)

View File

@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
use_grad_scaler=False,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred, x0)
loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd:
bwd_loss = loss / self.grad_acc
bwd_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
),
io.Combo.Input(
"training_dtype",
options=["bf16", "fp32"],
options=["bf16", "fp32", "none"],
default="bf16",
tooltip="The dtype to use for training.",
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
),
io.Combo.Input(
"lora_dtype",
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
tooltip="Offload model weights to CPU during training to save GPU memory.",
),
io.Combo.Input(
"existing_lora",
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
)
else:
train_sampler = TrainSampler(
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
)
# Setup guider
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input(
"steps",
optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
),
],
outputs=[],