Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (#13145)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

When training_dtype is set to "none" and the model's native dtype is
float16, GradScaler was unconditionally enabled. However, GradScaler
does not support bfloat16 gradients (only float16/float32), causing a
NotImplementedError when lora_dtype is "bf16" (the default).

Fix by only enabling GradScaler when LoRA parameters are not in
bfloat16, since bfloat16 has the same exponent range as float32 and
does not need gradient scaling to avoid underflow.

Fixes #13124
This commit is contained in:
Krishna Chaitanya 2026-03-24 20:53:44 -07:00 committed by GitHub
parent 7d5534d8e5
commit b53b10ea61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1146,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype
mp = model.clone()
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
@ -1154,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
@ -1165,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16