mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-17 07:05:12 +08:00
correct behavior
This commit is contained in:
parent
60f942e91b
commit
582ac60b29
@ -314,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(total_loss).backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
@ -1019,7 +1022,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
"training_dtype",
|
||||
options=["bf16", "fp32", "none"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.",
|
||||
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"lora_dtype",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user