diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 0296b810a..0ad0acee6 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -314,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
                 )
                 total_loss += loss
             total_loss = total_loss / self.grad_acc / len(indicies)
-            total_loss.backward()
+            if self.grad_scaler is not None:
+                self.grad_scaler.scale(total_loss).backward()
+            else:
+                total_loss.backward()
             if self.loss_callback:
                 self.loss_callback(total_loss.item())
             pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@@ -1019,7 +1022,7 @@ class TrainLoraNode(io.ComfyNode):
                 "training_dtype",
                 options=["bf16", "fp32", "none"],
                 default="bf16",
-                tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.",
+                tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
             ),
             io.Combo.Input(
                 "lora_dtype",
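
For context on the first hunk, a minimal sketch of the standard `torch.amp.GradScaler` loop that `self.grad_scaler.scale(total_loss).backward()` plugs into. The `step()`/`update()` half is assumed to live elsewhere in `TrainSampler` (not shown in this diff); the model, optimizer, and data below are illustrative placeholders rather than ComfyUI code, and the sketch uses the canonical autocast recipe even though the 'none' path runs the model in its native fp16 dtype without autocast.

```python
import torch

# Illustrative placeholders only, not ComfyUI code.
device = "cuda"
model = torch.nn.Linear(8, 8).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler(device)  # torch.cuda.amp.GradScaler() on older PyTorch

for step in range(4):
    optimizer.zero_grad(set_to_none=True)
    x = torch.randn(2, 8, device=device)
    with torch.autocast(device, dtype=torch.float16):
        loss = model(x).pow(2).mean()
    # Scale the loss before backward so small fp16 gradients do not
    # underflow to zero: the same call the diff adds to TrainSampler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales grads; skips the step on inf/NaN
    scaler.update()         # adapts the scale factor for the next step
```

The scale factor exists purely to keep small fp16 gradients from underflowing during `backward()`; `scaler.step()` unscales them before the optimizer sees them, so the parameter update itself is unaffected.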