From b53b10ea61ef7fc54fbde7c1e7b7c36565bacf82 Mon Sep 17 00:00:00 2001 From: Krishna Chaitanya Date: Tue, 24 Mar 2026 20:53:44 -0700 Subject: [PATCH] Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (#13145) When training_dtype is set to "none" and the model's native dtype is float16, GradScaler was unconditionally enabled. However, GradScaler does not support bfloat16 gradients (only float16/float32), causing a NotImplementedError when lora_dtype is "bf16" (the default). Fix by only enabling GradScaler when LoRA parameters are not in bfloat16, since bfloat16 has the same exponent range as float32 and does not need gradient scaling to avoid underflow. Fixes #13124 --- comfy_extras/nodes_train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index df1b39fd5..0616dfc2d 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1146,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode): # Setup model and dtype mp = model.clone() use_grad_scaler = False + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) if training_dtype != "none": dtype = node_helpers.string_to_torch_dtype(training_dtype) mp.set_model_compute_dtype(dtype) @@ -1154,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode): model_dtype = mp.model.get_dtype() if model_dtype == torch.float16: dtype = torch.float16 - use_grad_scaler = True + # GradScaler only supports float16 gradients, not bfloat16. + # Only enable it when lora params will also be in float16. + if lora_dtype != torch.bfloat16: + use_grad_scaler = True # Warn about fp16 accumulation instability during training if PerformanceFeature.Fp16Accumulation in args.fast: logging.warning( @@ -1165,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode): else: # For fp8, bf16, or other dtypes, use bf16 autocast dtype = torch.bfloat16 - lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) # Prepare latents and compute counts latents_dtype = dtype if dtype not in (None,) else torch.bfloat16