From 8427326f05c0648154f3a48826b03f88bfb17cc5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 00:40:56 +0800 Subject: [PATCH 01/11] Support native dtype training --- comfy_extras/nodes_train.py | 61 ++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index aa2d88673..4b19658d4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -15,6 +15,7 @@ import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management +from comfy.cli_args import args, PerformanceFeature import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers @@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler): training_dtype=torch.bfloat16, real_dataset=None, bucket_latents=None, + use_grad_scaler=False, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler): self.bucket_latents: list[torch.Tensor] | None = ( bucket_latents # list of (Bi, C, Hi, Wi) ) + # GradScaler for fp16 training + self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None # Precompute bucket offsets and weights for sampling if bucket_latents is not None: self._init_bucket_data(bucket_latents) @@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler): batch_sigmas.requires_grad_(True), **batch_extra_args, ) - loss = self.loss_fn(x0_pred, x0) + loss = self.loss_fn(x0_pred.float(), x0.float()) if bwd: bwd_loss = loss / self.grad_acc - bwd_loss.backward() + if self.grad_scaler is not None: + self.grad_scaler.scale(bwd_loss).backward() + else: + bwd_loss.backward() return loss def _generate_batch_sigmas(self, model_wrap, batch_size, device): @@ -348,12 +355,18 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + if self.grad_scaler is not None: + self.grad_scaler.unscale_(self.optimizer) for param_groups in self.optimizer.param_groups: for param in param_groups["params"]: if param.grad is None: continue param.grad.data = param.grad.data.to(param.data.dtype) - self.optimizer.step() + if self.grad_scaler is not None: + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) torch.cuda.empty_cache() @@ -1004,9 +1017,9 @@ class TrainLoraNode(io.ComfyNode): ), io.Combo.Input( "training_dtype", - options=["bf16", "fp32"], + options=["bf16", "fp32", "none"], default="bf16", - tooltip="The dtype to use for training.", + tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.", ), io.Combo.Input( "lora_dtype", @@ -1035,7 +1048,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Offload the Model to RAM. Requires Bypass Mode.", + tooltip="Depth level for gradient checkpointing.", ), io.Combo.Input( "existing_lora", @@ -1120,22 +1133,32 @@ class TrainLoraNode(io.ComfyNode): # Setup model and dtype mp = model.clone() - dtype = node_helpers.string_to_torch_dtype(training_dtype) + use_grad_scaler = False + if training_dtype != "none": + dtype = node_helpers.string_to_torch_dtype(training_dtype) + mp.set_model_compute_dtype(dtype) + else: + # Detect model's native dtype for autocast + model_dtype = mp.model.get_dtype() + if model_dtype == torch.float16: + dtype = torch.float16 + use_grad_scaler = True + # Warn about fp16 accumulation instability during training + if PerformanceFeature.Fp16Accumulation in args.fast: + logging.warning( + "WARNING: FP16 model detected with fp16_accumulation enabled. " + "This combination can be numerically unstable during training and may cause NaN values. " + "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." + ) + else: + # For fp8, bf16, or other dtypes, use bf16 autocast + dtype = torch.bfloat16 lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) - mp.set_model_compute_dtype(dtype) - - if mp.is_dynamic(): - if not bypass_mode: - logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode") - bypass_mode = True - offloading = True - elif offloading: - if not bypass_mode: - logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message") # Prepare latents and compute counts + latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 latents, num_images, multi_res = _prepare_latents_and_count( - latents, dtype, bucket_mode + latents, latents_dtype, bucket_mode ) # Validate and expand conditioning @@ -1201,6 +1224,7 @@ class TrainLoraNode(io.ComfyNode): seed=seed, training_dtype=dtype, bucket_latents=latents, + use_grad_scaler=use_grad_scaler, ) else: train_sampler = TrainSampler( @@ -1213,6 +1237,7 @@ class TrainLoraNode(io.ComfyNode): seed=seed, training_dtype=dtype, real_dataset=latents if multi_res else None, + use_grad_scaler=use_grad_scaler, ) # Setup guider From 6e2a2ee34286bedcad6aaf4bf50cdd51735fadef Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 00:41:10 +0800 Subject: [PATCH 02/11] Support quant linear fwdbwd --- comfy/ops.py | 108 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 98fec1e1d..23e6f88c3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -690,6 +690,73 @@ from .quant_ops import ( ) +class QuantLinearFunc(torch.autograd.Function): + """Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward. + """ + + @staticmethod + def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype): + # Save for backward + ctx.save_for_backward(input_float, weight) + ctx.has_bias = bias is not None + ctx.compute_dtype = compute_dtype + ctx.weight_requires_grad = weight.requires_grad + + # Detach: QuantizedTensor.from_float and the patched F.linear + # do not support tensors with requires_grad + inp = input_float.detach() + if inp.ndim >= 3: + inp = inp.reshape(-1, inp.shape[-1]) + + # Quantize input (same as inference path) + if layout_type is not None and inp.ndim == 2: + q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) + else: + q_input = inp + + w = weight.detach() if weight.requires_grad else weight + b = bias.detach() if bias is not None and bias.requires_grad else bias + + return torch.nn.functional.linear(q_input, w, b) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + input_float, weight = ctx.saved_tensors + compute_dtype = ctx.compute_dtype + + # Dequantize weight to compute dtype for backward matmul + if isinstance(weight, QuantizedTensor): + weight_f = weight.dequantize().to(compute_dtype) + else: + weight_f = weight.to(compute_dtype) + + # Cast grad_output to compute dtype (handles non-standard dtypes like fp8) + grad_output_f = grad_output.to(compute_dtype) + + # grad_input = grad_output @ weight + grad_input = grad_output_f.matmul(weight_f) + + # Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward) + if grad_input.shape != input_float.shape: + grad_input = grad_input.reshape(input_float.shape) + + # grad_weight (only if weight requires grad, typically frozen for quantized training) + grad_weight = None + if ctx.weight_requires_grad: + input_f = input_float.to(compute_dtype) + if input_f.ndim >= 3: + input_f = input_f.reshape(-1, input_f.shape[-1]) + grad_weight = grad_output_f.t().matmul(input_f) + + # grad_bias + grad_bias = None + if ctx.has_bias: + grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1))) + + return grad_input, grad_weight, grad_bias, None, None, None + + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config @@ -868,10 +935,42 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec #If cast needs to apply lora, it should be done in the compute dtype compute_dtype = input.dtype - if (getattr(self, 'layout_type', None) is not None and + _use_quantized = ( + getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and not getattr(self, 'comfy_force_cast_weights', False) and - len(self.weight_function) == 0 and len(self.bias_function) == 0): + len(self.weight_function) == 0 and len(self.bias_function) == 0 + ) + + # Training path: FP8 forward with compute_dtype backward via autograd function + # Only for FP8 layouts (not NVFP4 which packs 2 elements per byte) + if (input.requires_grad and _use_quantized and + getattr(self, 'layout_type', '').startswith('TensorCoreFP8')): + + weight, bias, offload_stream = cast_bias_weight( + self, + input, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=True + ) + + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + + output = QuantLinearFunc.apply( + input, weight, bias, self.layout_type, scale, compute_dtype + ) + + if input.ndim == 3: + output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0]) + + uncast_bias_weight(self, weight, bias, offload_stream) + return output + + # Inference path (unchanged) + if _use_quantized: # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input @@ -919,7 +1018,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec for key, param in self._parameters.items(): if param is None: continue - self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + p = fn(param) + if p.is_inference(): + p = p.clone() + self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) for key, buf in self._buffers.items(): if buf is not None: self._buffers[key] = fn(buf) From 3690e8134fe145dd5dae143d2f1d238e28a20566 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 00:41:20 +0800 Subject: [PATCH 03/11] Avoid inference/train tensor issue --- comfy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..f77acbdda 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -881,6 +881,10 @@ def set_attr(obj, attr, value): return prev def set_attr_param(obj, attr, value): + # Clone inference tensors (created under torch.inference_mode) since + # their version counter is frozen and nn.Parameter() cannot wrap them. + if value.is_inference(): + value = value.clone() return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): From 60f942e91bd354b00335b84ad6eba8193094f50c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 01:06:52 +0800 Subject: [PATCH 04/11] fix tooltip --- comfy_extras/nodes_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 4b19658d4..0296b810a 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1048,7 +1048,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Depth level for gradient checkpointing.", + tooltip="Offload model weights to CPU during training to save GPU memory.", ), io.Combo.Input( "existing_lora", @@ -1362,7 +1362,7 @@ class SaveLoRA(io.ComfyNode): io.Int.Input( "steps", optional=True, - tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.", ), ], outputs=[], From 582ac60b29d4eb91c67b1fa9b6c4913e79851669 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 01:15:05 +0800 Subject: [PATCH 05/11] correct behavior --- comfy_extras/nodes_train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0296b810a..0ad0acee6 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -314,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler): ) total_loss += loss total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() + if self.grad_scaler is not None: + self.grad_scaler.scale(total_loss).backward() + else: + total_loss.backward() if self.loss_callback: self.loss_callback(total_loss.item()) pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) @@ -1019,7 +1022,7 @@ class TrainLoraNode(io.ComfyNode): "training_dtype", options=["bf16", "fp32", "none"], default="bf16", - tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.", + tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.", ), io.Combo.Input( "lora_dtype", From 2e94badbe0621388ed805516df617454d67b71f9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 00:40:56 +0800 Subject: [PATCH 06/11] Support native dtype training --- comfy_extras/nodes_train.py | 61 ++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index aa2d88673..4b19658d4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -15,6 +15,7 @@ import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management +from comfy.cli_args import args, PerformanceFeature import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers @@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler): training_dtype=torch.bfloat16, real_dataset=None, bucket_latents=None, + use_grad_scaler=False, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler): self.bucket_latents: list[torch.Tensor] | None = ( bucket_latents # list of (Bi, C, Hi, Wi) ) + # GradScaler for fp16 training + self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None # Precompute bucket offsets and weights for sampling if bucket_latents is not None: self._init_bucket_data(bucket_latents) @@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler): batch_sigmas.requires_grad_(True), **batch_extra_args, ) - loss = self.loss_fn(x0_pred, x0) + loss = self.loss_fn(x0_pred.float(), x0.float()) if bwd: bwd_loss = loss / self.grad_acc - bwd_loss.backward() + if self.grad_scaler is not None: + self.grad_scaler.scale(bwd_loss).backward() + else: + bwd_loss.backward() return loss def _generate_batch_sigmas(self, model_wrap, batch_size, device): @@ -348,12 +355,18 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + if self.grad_scaler is not None: + self.grad_scaler.unscale_(self.optimizer) for param_groups in self.optimizer.param_groups: for param in param_groups["params"]: if param.grad is None: continue param.grad.data = param.grad.data.to(param.data.dtype) - self.optimizer.step() + if self.grad_scaler is not None: + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) torch.cuda.empty_cache() @@ -1004,9 +1017,9 @@ class TrainLoraNode(io.ComfyNode): ), io.Combo.Input( "training_dtype", - options=["bf16", "fp32"], + options=["bf16", "fp32", "none"], default="bf16", - tooltip="The dtype to use for training.", + tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.", ), io.Combo.Input( "lora_dtype", @@ -1035,7 +1048,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Offload the Model to RAM. Requires Bypass Mode.", + tooltip="Depth level for gradient checkpointing.", ), io.Combo.Input( "existing_lora", @@ -1120,22 +1133,32 @@ class TrainLoraNode(io.ComfyNode): # Setup model and dtype mp = model.clone() - dtype = node_helpers.string_to_torch_dtype(training_dtype) + use_grad_scaler = False + if training_dtype != "none": + dtype = node_helpers.string_to_torch_dtype(training_dtype) + mp.set_model_compute_dtype(dtype) + else: + # Detect model's native dtype for autocast + model_dtype = mp.model.get_dtype() + if model_dtype == torch.float16: + dtype = torch.float16 + use_grad_scaler = True + # Warn about fp16 accumulation instability during training + if PerformanceFeature.Fp16Accumulation in args.fast: + logging.warning( + "WARNING: FP16 model detected with fp16_accumulation enabled. " + "This combination can be numerically unstable during training and may cause NaN values. " + "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." + ) + else: + # For fp8, bf16, or other dtypes, use bf16 autocast + dtype = torch.bfloat16 lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) - mp.set_model_compute_dtype(dtype) - - if mp.is_dynamic(): - if not bypass_mode: - logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode") - bypass_mode = True - offloading = True - elif offloading: - if not bypass_mode: - logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message") # Prepare latents and compute counts + latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 latents, num_images, multi_res = _prepare_latents_and_count( - latents, dtype, bucket_mode + latents, latents_dtype, bucket_mode ) # Validate and expand conditioning @@ -1201,6 +1224,7 @@ class TrainLoraNode(io.ComfyNode): seed=seed, training_dtype=dtype, bucket_latents=latents, + use_grad_scaler=use_grad_scaler, ) else: train_sampler = TrainSampler( @@ -1213,6 +1237,7 @@ class TrainLoraNode(io.ComfyNode): seed=seed, training_dtype=dtype, real_dataset=latents if multi_res else None, + use_grad_scaler=use_grad_scaler, ) # Setup guider From eb33188c8e4a625b03de42ca1dfd3666402320ee Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 00:41:10 +0800 Subject: [PATCH 07/11] Support quant linear fwdbwd --- comfy/ops.py | 108 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..7ce45bc88 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -689,6 +689,73 @@ from .quant_ops import ( ) +class QuantLinearFunc(torch.autograd.Function): + """Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward. + """ + + @staticmethod + def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype): + # Save for backward + ctx.save_for_backward(input_float, weight) + ctx.has_bias = bias is not None + ctx.compute_dtype = compute_dtype + ctx.weight_requires_grad = weight.requires_grad + + # Detach: QuantizedTensor.from_float and the patched F.linear + # do not support tensors with requires_grad + inp = input_float.detach() + if inp.ndim >= 3: + inp = inp.reshape(-1, inp.shape[-1]) + + # Quantize input (same as inference path) + if layout_type is not None and inp.ndim == 2: + q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) + else: + q_input = inp + + w = weight.detach() if weight.requires_grad else weight + b = bias.detach() if bias is not None and bias.requires_grad else bias + + return torch.nn.functional.linear(q_input, w, b) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + input_float, weight = ctx.saved_tensors + compute_dtype = ctx.compute_dtype + + # Dequantize weight to compute dtype for backward matmul + if isinstance(weight, QuantizedTensor): + weight_f = weight.dequantize().to(compute_dtype) + else: + weight_f = weight.to(compute_dtype) + + # Cast grad_output to compute dtype (handles non-standard dtypes like fp8) + grad_output_f = grad_output.to(compute_dtype) + + # grad_input = grad_output @ weight + grad_input = grad_output_f.matmul(weight_f) + + # Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward) + if grad_input.shape != input_float.shape: + grad_input = grad_input.reshape(input_float.shape) + + # grad_weight (only if weight requires grad, typically frozen for quantized training) + grad_weight = None + if ctx.weight_requires_grad: + input_f = input_float.to(compute_dtype) + if input_f.ndim >= 3: + input_f = input_f.reshape(-1, input_f.shape[-1]) + grad_weight = grad_output_f.t().matmul(input_f) + + # grad_bias + grad_bias = None + if ctx.has_bias: + grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1))) + + return grad_input, grad_weight, grad_bias, None, None, None + + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config @@ -867,10 +934,42 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec #If cast needs to apply lora, it should be done in the compute dtype compute_dtype = input.dtype - if (getattr(self, 'layout_type', None) is not None and + _use_quantized = ( + getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and not getattr(self, 'comfy_force_cast_weights', False) and - len(self.weight_function) == 0 and len(self.bias_function) == 0): + len(self.weight_function) == 0 and len(self.bias_function) == 0 + ) + + # Training path: FP8 forward with compute_dtype backward via autograd function + # Only for FP8 layouts (not NVFP4 which packs 2 elements per byte) + if (input.requires_grad and _use_quantized and + getattr(self, 'layout_type', '').startswith('TensorCoreFP8')): + + weight, bias, offload_stream = cast_bias_weight( + self, + input, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=True + ) + + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + + output = QuantLinearFunc.apply( + input, weight, bias, self.layout_type, scale, compute_dtype + ) + + if input.ndim == 3: + output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0]) + + uncast_bias_weight(self, weight, bias, offload_stream) + return output + + # Inference path (unchanged) + if _use_quantized: # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input @@ -918,7 +1017,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec for key, param in self._parameters.items(): if param is None: continue - self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + p = fn(param) + if p.is_inference(): + p = p.clone() + self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) for key, buf in self._buffers.items(): if buf is not None: self._buffers[key] = fn(buf) From 3e433cd02d7155881dd0a2f7213a45a437eae02c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 00:41:20 +0800 Subject: [PATCH 08/11] Avoid inference/train tensor issue --- comfy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..f77acbdda 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -881,6 +881,10 @@ def set_attr(obj, attr, value): return prev def set_attr_param(obj, attr, value): + # Clone inference tensors (created under torch.inference_mode) since + # their version counter is frozen and nn.Parameter() cannot wrap them. + if value.is_inference(): + value = value.clone() return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def copy_to_param(obj, attr, value): From 0d7e529d7880f48f7d56ee3adf2a0d01959e54b5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 01:06:52 +0800 Subject: [PATCH 09/11] fix tooltip --- comfy_extras/nodes_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 4b19658d4..0296b810a 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1048,7 +1048,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Depth level for gradient checkpointing.", + tooltip="Offload model weights to CPU during training to save GPU memory.", ), io.Combo.Input( "existing_lora", @@ -1362,7 +1362,7 @@ class SaveLoRA(io.ComfyNode): io.Int.Input( "steps", optional=True, - tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.", ), ], outputs=[], From e82d7786fedcc9fa67b95c8919050a39323ed13b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Feb 2026 01:15:05 +0800 Subject: [PATCH 10/11] correct behavior --- comfy_extras/nodes_train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0296b810a..0ad0acee6 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -314,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler): ) total_loss += loss total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() + if self.grad_scaler is not None: + self.grad_scaler.scale(total_loss).backward() + else: + total_loss.backward() if self.loss_callback: self.loss_callback(total_loss.item()) pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) @@ -1019,7 +1022,7 @@ class TrainLoraNode(io.ComfyNode): "training_dtype", options=["bf16", "fp32", "none"], default="bf16", - tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.", + tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.", ), io.Combo.Input( "lora_dtype", From 4cebbc50f7fc22f705a55b79d704198e60a4ec35 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:05:50 +0800 Subject: [PATCH 11/11] Full fix on bad shape handling We also ensured comments are matching the logic --- comfy/ops.py | 61 +++++++++++++++++++++++----------------------------- 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 7ce45bc88..df752389b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -690,25 +690,17 @@ from .quant_ops import ( class QuantLinearFunc(torch.autograd.Function): - """Custom autograd function for FP8/FP4 linear: FP8/FP4 forward, compute_dtype backward. + """Custom autograd function for quantized linear: quantized forward, compute_dtype backward. + Handles any input rank by flattening to 2D for matmul and restoring shape after. """ @staticmethod def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype): - # Save for backward - ctx.save_for_backward(input_float, weight) - ctx.has_bias = bias is not None - ctx.compute_dtype = compute_dtype - ctx.weight_requires_grad = weight.requires_grad - - # Detach: QuantizedTensor.from_float and the patched F.linear - # do not support tensors with requires_grad - inp = input_float.detach() - if inp.ndim >= 3: - inp = inp.reshape(-1, inp.shape[-1]) + input_shape = input_float.shape + inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D # Quantize input (same as inference path) - if layout_type is not None and inp.ndim == 2: + if layout_type is not None: q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) else: q_input = inp @@ -716,13 +708,26 @@ class QuantLinearFunc(torch.autograd.Function): w = weight.detach() if weight.requires_grad else weight b = bias.detach() if bias is not None and bias.requires_grad else bias - return torch.nn.functional.linear(q_input, w, b) + output = torch.nn.functional.linear(q_input, w, b) + + # Restore original input shape + if len(input_shape) > 2: + output = output.unflatten(0, input_shape[:-1]) + + ctx.save_for_backward(input_float, weight) + ctx.input_shape = input_shape + ctx.has_bias = bias is not None + ctx.compute_dtype = compute_dtype + ctx.weight_requires_grad = weight.requires_grad + + return output @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, grad_output): input_float, weight = ctx.saved_tensors compute_dtype = ctx.compute_dtype + grad_2d = grad_output.flatten(0, -2).to(compute_dtype) # Dequantize weight to compute dtype for backward matmul if isinstance(weight, QuantizedTensor): @@ -730,28 +735,21 @@ class QuantLinearFunc(torch.autograd.Function): else: weight_f = weight.to(compute_dtype) - # Cast grad_output to compute dtype (handles non-standard dtypes like fp8) - grad_output_f = grad_output.to(compute_dtype) - # grad_input = grad_output @ weight - grad_input = grad_output_f.matmul(weight_f) - - # Reshape to match original input shape (e.g. 3D input was flattened to 2D in forward) - if grad_input.shape != input_float.shape: - grad_input = grad_input.reshape(input_float.shape) + grad_input = torch.mm(grad_2d, weight_f) + if len(ctx.input_shape) > 2: + grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) # grad_weight (only if weight requires grad, typically frozen for quantized training) grad_weight = None if ctx.weight_requires_grad: - input_f = input_float.to(compute_dtype) - if input_f.ndim >= 3: - input_f = input_f.reshape(-1, input_f.shape[-1]) - grad_weight = grad_output_f.t().matmul(input_f) + input_f = input_float.flatten(0, -2).to(compute_dtype) + grad_weight = torch.mm(grad_2d.t(), input_f) # grad_bias grad_bias = None if ctx.has_bias: - grad_bias = grad_output_f.sum(dim=list(range(grad_output_f.ndim - 1))) + grad_bias = grad_2d.sum(dim=0) return grad_input, grad_weight, grad_bias, None, None, None @@ -941,10 +939,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec len(self.weight_function) == 0 and len(self.bias_function) == 0 ) - # Training path: FP8 forward with compute_dtype backward via autograd function - # Only for FP8 layouts (not NVFP4 which packs 2 elements per byte) - if (input.requires_grad and _use_quantized and - getattr(self, 'layout_type', '').startswith('TensorCoreFP8')): + # Training path: quantized forward with compute_dtype backward via autograd function + if (input.requires_grad and _use_quantized): weight, bias, offload_stream = cast_bias_weight( self, @@ -962,9 +958,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input, weight, bias, self.layout_type, scale, compute_dtype ) - if input.ndim == 3: - output = output.reshape(input_shape[0], input_shape[1], self.weight.shape[0]) - uncast_bias_weight(self, weight, bias, offload_stream) return output