Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-03-17 23:25:05 +08:00)
Support native dtype training
commit 2e94badbe0 (parent afb54219fa)
@@ -15,6 +15,7 @@ import comfy.sampler_helpers
 import comfy.sd
 import comfy.utils
 import comfy.model_management
+from comfy.cli_args import args, PerformanceFeature
 import comfy_extras.nodes_custom_sampler
 import folder_paths
 import node_helpers
@@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
         training_dtype=torch.bfloat16,
         real_dataset=None,
         bucket_latents=None,
+        use_grad_scaler=False,
     ):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
@@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
         self.bucket_latents: list[torch.Tensor] | None = (
             bucket_latents  # list of (Bi, C, Hi, Wi)
         )
+        # GradScaler for fp16 training
+        self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
             self._init_bucket_data(bucket_latents)
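For context on the line added above: torch.amp.GradScaler protects fp16 training from gradient underflow by scaling the loss before backward and restoring true gradient magnitudes before the optimizer step. The sketch below is illustrative only and not part of this commit; it assumes a CUDA device, and the "frozen" and "lora" names are placeholders standing in for an fp16 base model with fp32 trainable adapter weights.

import torch

# Illustrative only, not part of this commit. Assumes a CUDA device; "frozen"
# stands in for an fp16 base model and "lora" for fp32 trainable adapter weights.
frozen = torch.nn.Linear(8, 8).cuda().half().requires_grad_(False)
lora = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(lora.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler()

x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
out = frozen(x) + lora(x.float()).half()   # fp16 forward path, no autocast
loss = out.float().pow(2).mean()           # reduce the loss in fp32

scaler.scale(loss).backward()   # scale the loss so small fp16 grads don't underflow
scaler.unscale_(optimizer)      # restore true gradient magnitudes (fp32 params, so allowed)
scaler.step(optimizer)          # skips the update if any gradient is inf/NaN
scaler.update()                 # adapt the scale factor for the next iteration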
@@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
             batch_sigmas.requires_grad_(True),
             **batch_extra_args,
         )
-        loss = self.loss_fn(x0_pred, x0)
+        loss = self.loss_fn(x0_pred.float(), x0.float())
         if bwd:
             bwd_loss = loss / self.grad_acc
-            bwd_loss.backward()
+            if self.grad_scaler is not None:
+                self.grad_scaler.scale(bwd_loss).backward()
+            else:
+                bwd_loss.backward()
         return loss
 
     def _generate_batch_sigmas(self, model_wrap, batch_size, device):
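The change from self.loss_fn(x0_pred, x0) to self.loss_fn(x0_pred.float(), x0.float()) upcasts the prediction and target before the reduction, because a squared-error loss evaluated in fp16 can overflow or lose precision even when its inputs are finite. A tiny illustration of the overflow case, with made-up values that are not from this commit:

import torch

# Made-up values: the gap between prediction and target is 600, whose square
# (360000) exceeds the fp16 maximum of 65504, so the fp16 loss overflows to inf.
x0_pred = torch.full((4,), 300.0, dtype=torch.float16)
x0 = torch.full((4,), -300.0, dtype=torch.float16)

print(((x0_pred - x0) ** 2).mean())                  # tensor(inf, dtype=torch.float16)
print(((x0_pred.float() - x0.float()) ** 2).mean())  # tensor(360000.)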
@@ -348,12 +355,18 @@ class TrainSampler(comfy.samplers.Sampler):
                 self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
 
             if (i + 1) % self.grad_acc == 0:
+                if self.grad_scaler is not None:
+                    self.grad_scaler.unscale_(self.optimizer)
                 for param_groups in self.optimizer.param_groups:
                     for param in param_groups["params"]:
                         if param.grad is None:
                             continue
                         param.grad.data = param.grad.data.to(param.data.dtype)
-                self.optimizer.step()
+                if self.grad_scaler is not None:
+                    self.grad_scaler.step(self.optimizer)
+                    self.grad_scaler.update()
+                else:
+                    self.optimizer.step()
                 self.optimizer.zero_grad()
             ui_pbar.update(1)
         torch.cuda.empty_cache()
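The optimizer-step branch above follows the usual GradScaler pattern for gradient accumulation: unscale once per effective step, cast gradients to the parameter dtype, then let the scaler drive optimizer.step() so steps containing inf/NaN gradients are skipped. Below is a self-contained sketch of that flow under assumed toy settings; the model, data, and grad_acc value are placeholders, and the scaler is only created when CUDA is available.

import torch

# Self-contained sketch of the accumulate / unscale_ / step pattern in the hunk
# above. The toy model, data, and grad_acc value are assumptions for illustration.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler() if device == "cuda" else None
grad_acc = 4

for i in range(8):
    x = torch.randn(2, 8, device=device)
    loss = model(x).pow(2).mean() / grad_acc
    if scaler is not None:
        scaler.scale(loss).backward()
    else:
        loss.backward()

    if (i + 1) % grad_acc == 0:
        if scaler is not None:
            scaler.unscale_(optimizer)           # gradients back to true scale
        for group in optimizer.param_groups:     # cast grads to the param dtype,
            for p in group["params"]:            # as the training loop above does
                if p.grad is not None:
                    p.grad.data = p.grad.data.to(p.data.dtype)
        if scaler is not None:
            scaler.step(optimizer)               # skipped if any grad is inf/NaN
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()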
@@ -1004,9 +1017,9 @@ class TrainLoraNode(io.ComfyNode):
                 ),
                 io.Combo.Input(
                     "training_dtype",
-                    options=["bf16", "fp32"],
+                    options=["bf16", "fp32", "none"],
                     default="bf16",
-                    tooltip="The dtype to use for training.",
+                    tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.",
                 ),
                 io.Combo.Input(
                     "lora_dtype",
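Per the new tooltip, 'bf16' and 'fp32' still select an explicit compute dtype, while 'none' keeps the model's native dtype and enables the GradScaler path. The sketch below shows the kind of string-to-dtype mapping this option list implies; the real node_helpers.string_to_torch_dtype may differ, and 'none' never reaches it because the node handles that case separately.

import torch
from typing import Optional

# Illustrative mapping only; the actual node_helpers.string_to_torch_dtype in
# ComfyUI may differ. "none" is handled by the node itself (native dtype plus
# GradScaler), so it never reaches a mapping like this.
_DTYPE_OPTIONS = {
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}

def string_to_torch_dtype_sketch(name: str) -> Optional[torch.dtype]:
    return _DTYPE_OPTIONS.get(name)

assert string_to_torch_dtype_sketch("bf16") is torch.bfloat16
assert string_to_torch_dtype_sketch("none") is None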
@@ -1035,7 +1048,7 @@ class TrainLoraNode(io.ComfyNode):
                 io.Boolean.Input(
                     "offloading",
                     default=False,
-                    tooltip="Offload the Model to RAM. Requires Bypass Mode.",
+                    tooltip="Depth level for gradient checkpointing.",
                 ),
                 io.Combo.Input(
                     "existing_lora",
@@ -1120,22 +1133,32 @@ class TrainLoraNode(io.ComfyNode):
 
         # Setup model and dtype
         mp = model.clone()
-        dtype = node_helpers.string_to_torch_dtype(training_dtype)
+        use_grad_scaler = False
+        if training_dtype != "none":
+            dtype = node_helpers.string_to_torch_dtype(training_dtype)
+            mp.set_model_compute_dtype(dtype)
+        else:
+            # Detect model's native dtype for autocast
+            model_dtype = mp.model.get_dtype()
+            if model_dtype == torch.float16:
+                dtype = torch.float16
+                use_grad_scaler = True
+                # Warn about fp16 accumulation instability during training
+                if PerformanceFeature.Fp16Accumulation in args.fast:
+                    logging.warning(
+                        "WARNING: FP16 model detected with fp16_accumulation enabled. "
+                        "This combination can be numerically unstable during training and may cause NaN values. "
+                        "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
+                    )
+            else:
+                # For fp8, bf16, or other dtypes, use bf16 autocast
+                dtype = torch.bfloat16
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
-        mp.set_model_compute_dtype(dtype)
 
         if mp.is_dynamic():
             if not bypass_mode:
                 logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
             bypass_mode = True
             offloading = True
         elif offloading:
             if not bypass_mode:
                 logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
 
         # Prepare latents and compute counts
+        latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
         latents, num_images, multi_res = _prepare_latents_and_count(
-            latents, dtype, bucket_mode
+            latents, latents_dtype, bucket_mode
         )
 
         # Validate and expand conditioning
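The training_dtype == "none" branch above boils down to one rule: an fp16 model trains in fp16 with a GradScaler, while fp8, bf16, and other dtypes fall back to bf16 without one. A standalone restatement of that rule as a small helper; the function name is made up for illustration and is not part of the commit.

import torch

def pick_native_training_setup(model_dtype: torch.dtype):
    """Hypothetical helper mirroring the branch above.

    Returns (training_dtype, use_grad_scaler) for the training_dtype == "none"
    case: fp16 keeps fp16 and enables a GradScaler; fp8/bf16/other dtypes fall
    back to bf16 without one.
    """
    if model_dtype == torch.float16:
        return torch.float16, True
    return torch.bfloat16, False

assert pick_native_training_setup(torch.float16) == (torch.float16, True)
assert pick_native_training_setup(torch.float8_e4m3fn) == (torch.bfloat16, False)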
@@ -1201,6 +1224,7 @@ class TrainLoraNode(io.ComfyNode):
                 seed=seed,
                 training_dtype=dtype,
                 bucket_latents=latents,
+                use_grad_scaler=use_grad_scaler,
             )
         else:
             train_sampler = TrainSampler(
@@ -1213,6 +1237,7 @@ class TrainLoraNode(io.ComfyNode):
                 seed=seed,
                 training_dtype=dtype,
                 real_dataset=latents if multi_res else None,
+                use_grad_scaler=use_grad_scaler,
             )
 
         # Setup guider