diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index aa2d88673..4b19658d4 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -15,6 +15,7 @@ import comfy.sampler_helpers
 import comfy.sd
 import comfy.utils
 import comfy.model_management
+from comfy.cli_args import args, PerformanceFeature
 import comfy_extras.nodes_custom_sampler
 import folder_paths
 import node_helpers
@@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
         training_dtype=torch.bfloat16,
         real_dataset=None,
         bucket_latents=None,
+        use_grad_scaler=False,
     ):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
@@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
         self.bucket_latents: list[torch.Tensor] | None = (
             bucket_latents  # list of (Bi, C, Hi, Wi)
         )
+        # GradScaler for fp16 training
+        self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
             self._init_bucket_data(bucket_latents)
@@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
             batch_sigmas.requires_grad_(True),
             **batch_extra_args,
         )
-        loss = self.loss_fn(x0_pred, x0)
+        loss = self.loss_fn(x0_pred.float(), x0.float())
         if bwd:
             bwd_loss = loss / self.grad_acc
-            bwd_loss.backward()
+            if self.grad_scaler is not None:
+                self.grad_scaler.scale(bwd_loss).backward()
+            else:
+                bwd_loss.backward()
         return loss
 
     def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@@ -348,12 +355,18 @@ class TrainSampler(comfy.samplers.Sampler):
                     self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
 
                 if (i + 1) % self.grad_acc == 0:
+                    if self.grad_scaler is not None:
+                        self.grad_scaler.unscale_(self.optimizer)
                     for param_groups in self.optimizer.param_groups:
                         for param in param_groups["params"]:
                             if param.grad is None:
                                 continue
                             param.grad.data = param.grad.data.to(param.data.dtype)
-                    self.optimizer.step()
+                    if self.grad_scaler is not None:
+                        self.grad_scaler.step(self.optimizer)
+                        self.grad_scaler.update()
+                    else:
+                        self.optimizer.step()
                     self.optimizer.zero_grad()
                     ui_pbar.update(1)
                     torch.cuda.empty_cache()
@@ -1004,9 +1017,9 @@ class TrainLoraNode(io.ComfyNode):
                 ),
                 io.Combo.Input(
                     "training_dtype",
-                    options=["bf16", "fp32"],
+                    options=["bf16", "fp32", "none"],
                     default="bf16",
-                    tooltip="The dtype to use for training.",
+                    tooltip="The dtype to use for training. 'none' disables autocast and uses the model's native dtype with GradScaler.",
                 ),
                 io.Combo.Input(
                     "lora_dtype",
@@ -1035,7 +1048,7 @@ class TrainLoraNode(io.ComfyNode):
                 io.Boolean.Input(
                     "offloading",
                     default=False,
-                    tooltip="Offload the Model to RAM. Requires Bypass Mode.",
+                    tooltip="Depth level for gradient checkpointing.",
                 ),
                 io.Combo.Input(
                     "existing_lora",
@@ -1120,22 +1133,32 @@ class TrainLoraNode(io.ComfyNode):
 
         # Setup model and dtype
         mp = model.clone()
-        dtype = node_helpers.string_to_torch_dtype(training_dtype)
+        use_grad_scaler = False
+        if training_dtype != "none":
+            dtype = node_helpers.string_to_torch_dtype(training_dtype)
+            mp.set_model_compute_dtype(dtype)
+        else:
+            # Detect model's native dtype for autocast
+            model_dtype = mp.model.get_dtype()
+            if model_dtype == torch.float16:
+                dtype = torch.float16
+                use_grad_scaler = True
+                # Warn about fp16 accumulation instability during training
+                if PerformanceFeature.Fp16Accumulation in args.fast:
+                    logging.warning(
+                        "WARNING: FP16 model detected with fp16_accumulation enabled. "
+                        "This combination can be numerically unstable during training and may cause NaN values. "
+                        "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
+                    )
+            else:
+                # For fp8, bf16, or other dtypes, use bf16 autocast
+                dtype = torch.bfloat16
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
-        mp.set_model_compute_dtype(dtype)
-
-        if mp.is_dynamic():
-            if not bypass_mode:
-                logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
-            bypass_mode = True
-            offloading = True
-        elif offloading:
-            if not bypass_mode:
-                logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
 
         # Prepare latents and compute counts
+        latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
         latents, num_images, multi_res = _prepare_latents_and_count(
-            latents, dtype, bucket_mode
+            latents, latents_dtype, bucket_mode
         )
 
         # Validate and expand conditioning
@@ -1201,6 +1224,7 @@ class TrainLoraNode(io.ComfyNode):
                 seed=seed,
                 training_dtype=dtype,
                 bucket_latents=latents,
+                use_grad_scaler=use_grad_scaler,
             )
         else:
             train_sampler = TrainSampler(
@@ -1213,6 +1237,7 @@ class TrainLoraNode(io.ComfyNode):
                 seed=seed,
                 training_dtype=dtype,
                 real_dataset=latents if multi_res else None,
+                use_grad_scaler=use_grad_scaler,
            )
 
         # Setup guider