diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index df1b39fd5..0616dfc2d 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1146,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode): # Setup model and dtype mp = model.clone() use_grad_scaler = False + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) if training_dtype != "none": dtype = node_helpers.string_to_torch_dtype(training_dtype) mp.set_model_compute_dtype(dtype) @@ -1154,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode): model_dtype = mp.model.get_dtype() if model_dtype == torch.float16: dtype = torch.float16 - use_grad_scaler = True + # GradScaler only supports float16 gradients, not bfloat16. + # Only enable it when lora params will also be in float16. + if lora_dtype != torch.bfloat16: + use_grad_scaler = True # Warn about fp16 accumulation instability during training if PerformanceFeature.Fp16Accumulation in args.fast: logging.warning( @@ -1165,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode): else: # For fp8, bf16, or other dtypes, use bf16 autocast dtype = torch.bfloat16 - lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) # Prepare latents and compute counts latents_dtype = dtype if dtype not in (None,) else torch.bfloat16