mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 17:43:33 +08:00
Merge 4cebbc50f7 into 0904cc3fe5
This commit is contained in:
commit
54cb87f070
101
comfy/ops.py
101
comfy/ops.py
@ -766,6 +766,71 @@ from .quant_ops import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
|
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||||
|
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||||
|
input_shape = input_float.shape
|
||||||
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
|
# Quantize input (same as inference path)
|
||||||
|
if layout_type is not None:
|
||||||
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
|
else:
|
||||||
|
q_input = inp
|
||||||
|
|
||||||
|
w = weight.detach() if weight.requires_grad else weight
|
||||||
|
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||||
|
|
||||||
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
|
# Restore original input shape
|
||||||
|
if len(input_shape) > 2:
|
||||||
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
ctx.input_shape = input_shape
|
||||||
|
ctx.has_bias = bias is not None
|
||||||
|
ctx.compute_dtype = compute_dtype
|
||||||
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_float, weight = ctx.saved_tensors
|
||||||
|
compute_dtype = ctx.compute_dtype
|
||||||
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
|
# Dequantize weight to compute dtype for backward matmul
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_f = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_f = weight.to(compute_dtype)
|
||||||
|
|
||||||
|
# grad_input = grad_output @ weight
|
||||||
|
grad_input = torch.mm(grad_2d, weight_f)
|
||||||
|
if len(ctx.input_shape) > 2:
|
||||||
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
|
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||||
|
grad_weight = None
|
||||||
|
if ctx.weight_requires_grad:
|
||||||
|
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||||
|
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||||
|
|
||||||
|
# grad_bias
|
||||||
|
grad_bias = None
|
||||||
|
if ctx.has_bias:
|
||||||
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
|
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
_quant_config = quant_config
|
_quant_config = quant_config
|
||||||
@ -960,10 +1025,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
#If cast needs to apply lora, it should be done in the compute dtype
|
#If cast needs to apply lora, it should be done in the compute dtype
|
||||||
compute_dtype = input.dtype
|
compute_dtype = input.dtype
|
||||||
|
|
||||||
if (getattr(self, 'layout_type', None) is not None and
|
_use_quantized = (
|
||||||
|
getattr(self, 'layout_type', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||||
|
if (input.requires_grad and _use_quantized):
|
||||||
|
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(
|
||||||
|
self,
|
||||||
|
input,
|
||||||
|
offloadable=True,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
want_requant=True
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = getattr(self, 'input_scale', None)
|
||||||
|
if scale is not None:
|
||||||
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
|
|
||||||
|
output = QuantLinearFunc.apply(
|
||||||
|
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return output
|
||||||
|
|
||||||
|
# Inference path (unchanged)
|
||||||
|
if _use_quantized:
|
||||||
|
|
||||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||||
@ -1011,7 +1103,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
for key, param in self._parameters.items():
|
for key, param in self._parameters.items():
|
||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
p = fn(param)
|
||||||
|
if p.is_inference():
|
||||||
|
p = p.clone()
|
||||||
|
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
for key, buf in self._buffers.items():
|
for key, buf in self._buffers.items():
|
||||||
if buf is not None:
|
if buf is not None:
|
||||||
self._buffers[key] = fn(buf)
|
self._buffers[key] = fn(buf)
|
||||||
|
|||||||
@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
|
|||||||
return prev
|
return prev
|
||||||
|
|
||||||
def set_attr_param(obj, attr, value):
|
def set_attr_param(obj, attr, value):
|
||||||
|
# Clone inference tensors (created under torch.inference_mode) since
|
||||||
|
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||||
|
if value.is_inference():
|
||||||
|
value = value.clone()
|
||||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||||
|
|
||||||
def set_attr_buffer(obj, attr, value):
|
def set_attr_buffer(obj, attr, value):
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
|||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy_extras.nodes_custom_sampler
|
import comfy_extras.nodes_custom_sampler
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
training_dtype=torch.bfloat16,
|
training_dtype=torch.bfloat16,
|
||||||
real_dataset=None,
|
real_dataset=None,
|
||||||
bucket_latents=None,
|
bucket_latents=None,
|
||||||
|
use_grad_scaler=False,
|
||||||
):
|
):
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self.bucket_latents: list[torch.Tensor] | None = (
|
self.bucket_latents: list[torch.Tensor] | None = (
|
||||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||||
)
|
)
|
||||||
|
# GradScaler for fp16 training
|
||||||
|
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||||
# Precompute bucket offsets and weights for sampling
|
# Precompute bucket offsets and weights for sampling
|
||||||
if bucket_latents is not None:
|
if bucket_latents is not None:
|
||||||
self._init_bucket_data(bucket_latents)
|
self._init_bucket_data(bucket_latents)
|
||||||
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
batch_sigmas.requires_grad_(True),
|
batch_sigmas.requires_grad_(True),
|
||||||
**batch_extra_args,
|
**batch_extra_args,
|
||||||
)
|
)
|
||||||
loss = self.loss_fn(x0_pred, x0)
|
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||||
if bwd:
|
if bwd:
|
||||||
bwd_loss = loss / self.grad_acc
|
bwd_loss = loss / self.grad_acc
|
||||||
bwd_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(bwd_loss).backward()
|
||||||
|
else:
|
||||||
|
bwd_loss.backward()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||||
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
)
|
)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||||
total_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(total_loss).backward()
|
||||||
|
else:
|
||||||
|
total_loss.backward()
|
||||||
if self.loss_callback:
|
if self.loss_callback:
|
||||||
self.loss_callback(total_loss.item())
|
self.loss_callback(total_loss.item())
|
||||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||||
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||||
|
|
||||||
if (i + 1) % self.grad_acc == 0:
|
if (i + 1) % self.grad_acc == 0:
|
||||||
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.unscale_(self.optimizer)
|
||||||
for param_groups in self.optimizer.param_groups:
|
for param_groups in self.optimizer.param_groups:
|
||||||
for param in param_groups["params"]:
|
for param in param_groups["params"]:
|
||||||
if param.grad is None:
|
if param.grad is None:
|
||||||
continue
|
continue
|
||||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||||
self.optimizer.step()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.step(self.optimizer)
|
||||||
|
self.grad_scaler.update()
|
||||||
|
else:
|
||||||
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
ui_pbar.update(1)
|
ui_pbar.update(1)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"training_dtype",
|
"training_dtype",
|
||||||
options=["bf16", "fp32"],
|
options=["bf16", "fp32", "none"],
|
||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for training.",
|
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"lora_dtype",
|
"lora_dtype",
|
||||||
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
io.Boolean.Input(
|
io.Boolean.Input(
|
||||||
"offloading",
|
"offloading",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"existing_lora",
|
"existing_lora",
|
||||||
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
use_grad_scaler = False
|
||||||
|
if training_dtype != "none":
|
||||||
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
else:
|
||||||
|
# Detect model's native dtype for autocast
|
||||||
|
model_dtype = mp.model.get_dtype()
|
||||||
|
if model_dtype == torch.float16:
|
||||||
|
dtype = torch.float16
|
||||||
|
use_grad_scaler = True
|
||||||
|
# Warn about fp16 accumulation instability during training
|
||||||
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
|
logging.warning(
|
||||||
|
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||||
|
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||||
|
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
|
dtype = torch.bfloat16
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
|
||||||
|
|
||||||
if mp.is_dynamic():
|
|
||||||
if not bypass_mode:
|
|
||||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
|
||||||
bypass_mode = True
|
|
||||||
offloading = True
|
|
||||||
elif offloading:
|
|
||||||
if not bypass_mode:
|
|
||||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||||
latents, dtype, bucket_mode
|
latents, latents_dtype, bucket_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate and expand conditioning
|
# Validate and expand conditioning
|
||||||
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
bucket_latents=latents,
|
bucket_latents=latents,
|
||||||
|
use_grad_scaler=use_grad_scaler,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = TrainSampler(
|
train_sampler = TrainSampler(
|
||||||
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
real_dataset=latents if multi_res else None,
|
real_dataset=latents if multi_res else None,
|
||||||
|
use_grad_scaler=use_grad_scaler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
|
|||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user