correct behavior

This commit is contained in:
Kohaku-Blueleaf 2026-02-28 01:15:05 +08:00
parent 60f942e91b
commit 582ac60b29

View File

@ -314,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@ -1019,7 +1022,7 @@ class TrainLoraNode(io.ComfyNode):
"training_dtype",
options=["bf16", "fp32", "none"],
default="bf16",
tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.",
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
),
io.Combo.Input(
"lora_dtype",