mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 09:33:29 +08:00
correct behavior
This commit is contained in:
parent
60f942e91b
commit
582ac60b29
@ -314,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
)
|
)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||||
total_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(total_loss).backward()
|
||||||
|
else:
|
||||||
|
total_loss.backward()
|
||||||
if self.loss_callback:
|
if self.loss_callback:
|
||||||
self.loss_callback(total_loss.item())
|
self.loss_callback(total_loss.item())
|
||||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||||
@ -1019,7 +1022,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
"training_dtype",
|
"training_dtype",
|
||||||
options=["bf16", "fp32", "none"],
|
options=["bf16", "fp32", "none"],
|
||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.",
|
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"lora_dtype",
|
"lora_dtype",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user