From 3a82bd15e788c97dbf1db3534824ca35fe642817 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Mon, 23 Mar 2026 19:32:21 +0800
Subject: [PATCH] Add dev run mode for vram profiling or debugging

---
 comfy_extras/nodes_train.py | 140 ++++++++++++++++++++++++++++++++++++
 1 file changed, 140 insertions(+)

diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 0ad0acee6..e9b95e414 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -140,6 +140,7 @@ class TrainSampler(comfy.samplers.Sampler):
         real_dataset=None,
         bucket_latents=None,
         use_grad_scaler=False,
+        dev_run=False,
     ):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
@@ -156,6 +157,7 @@
         )
         # GradScaler for fp16 training
         self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
+        self.dev_run = dev_run
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
             self._init_bucket_data(bucket_latents)
@@ -186,6 +188,129 @@
         extra_args,
         dataset_size,
         bwd=True,
+    ):
+        if self.dev_run:
+            return self._fwd_bwd_dev(
+                model_wrap, batch_sigmas, batch_noise, batch_latent,
+                cond, indicies, extra_args, dataset_size, bwd,
+            )
+        return self._fwd_bwd_impl(
+            model_wrap, batch_sigmas, batch_noise, batch_latent,
+            cond, indicies, extra_args, dataset_size, bwd,
+        )
+
+    @staticmethod
+    def _vram_info(tag):
+        """Log VRAM usage from both PyTorch allocator and actual GPU."""
+        torch.cuda.synchronize()
+        allocated = torch.cuda.memory_allocated()
+        reserved = torch.cuda.memory_reserved()
+        peak = torch.cuda.max_memory_allocated()
+        free, total = torch.cuda.mem_get_info()
+        gpu_used = total - free
+        logging.info(
+            f"[DevRun] {tag}\n"
+            f"    PyTorch allocated: {allocated / 1024**2:.1f} MB | "
+            f"reserved: {reserved / 1024**2:.1f} MB | "
+            f"peak allocated: {peak / 1024**2:.1f} MB\n"
+            f"    GPU real usage: {gpu_used / 1024**2:.1f} MB / {total / 1024**2:.1f} MB | "
+            f"non-PyTorch: {(gpu_used - reserved) / 1024**2:.1f} MB"
+        )
+
+    def _fwd_bwd_dev(
+        self,
+        model_wrap,
+        batch_sigmas,
+        batch_noise,
+        batch_latent,
+        cond,
+        indicies,
+        extra_args,
+        dataset_size,
+        bwd,
+    ):
+        """Wraps fwd_bwd with CUDA memory profiling for dev_run mode."""
+        output_dir = folder_paths.get_output_directory()
+        snapshot_path = os.path.join(output_dir, "dev_run_memory_snapshot.pkl")
+        fwd_args = (model_wrap, batch_sigmas, batch_noise, batch_latent,
+                    cond, indicies, extra_args, dataset_size)
+
+        # ── Phase 0: no_grad forward-only reference ──
+        logging.info("[DevRun] ═══ Phase 0: no_grad forward-only reference ═══")
+        torch.cuda.reset_peak_memory_stats()
+        self._vram_info("Before no_grad fwd")
+        with torch.no_grad():
+            self._fwd_bwd_impl(*fwd_args, bwd=False)
+        self._vram_info("After no_grad fwd (activations freed)")
+        logging.info("[DevRun] ═══ End Phase 0 ═══\n")
+        torch.cuda.empty_cache()
+
+        # ── Phase 1: forward pass (with grad, no backward) ──
+        logging.info("[DevRun] ═══ Phase 1: forward pass (with grad) ═══")
+        torch.cuda.reset_peak_memory_stats()
+        self._vram_info("Before fwd")
+
+        # Record memory history with Python-only stacks (works on Windows)
+        torch.cuda.memory._record_memory_history(max_entries=100000, stacks="python")
+
+        # Inline the forward part of _fwd_bwd_impl so we can measure before backward
+        xt = model_wrap.inner_model.model_sampling.noise_scaling(
+            batch_sigmas, batch_noise, batch_latent, False
+        )
+        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+            torch.zeros_like(batch_sigmas),
+            torch.zeros_like(batch_noise),
+            batch_latent,
+            False,
+        )
+        model_wrap.conds["positive"] = [cond[i] for i in indicies]
+        batch_extra_args = make_batch_extra_option_dict(
+            extra_args, indicies, full_size=dataset_size
+        )
+        with torch.autocast(xt.device.type, dtype=self.training_dtype):
+            x0_pred = model_wrap(
+                xt.requires_grad_(True),
+                batch_sigmas.requires_grad_(True),
+                **batch_extra_args,
+            )
+            loss = self.loss_fn(x0_pred.float(), x0.float())
+
+        self._vram_info("After fwd (autograd graph alive)")
+
+        # ── Phase 2: backward pass ──
+        logging.info("[DevRun] ═══ Phase 2: backward pass ═══")
+        torch.cuda.reset_peak_memory_stats()
+
+        if bwd:
+            bwd_loss = loss / self.grad_acc
+            if self.grad_scaler is not None:
+                self.grad_scaler.scale(bwd_loss).backward()
+            else:
+                bwd_loss.backward()
+
+        self._vram_info("After bwd (grads computed)")
+
+        # ── Dump snapshot ──
+        torch.cuda.memory._dump_snapshot(snapshot_path)
+        torch.cuda.memory._record_memory_history(enabled=None)
+        logging.info(
+            f"[DevRun] Memory snapshot saved to: {snapshot_path}\n"
+            f"  → Visualize at https://pytorch.org/memory_viz (drag in the .pkl file)"
+        )
+
+        return loss
+
+    def _fwd_bwd_impl(
+        self,
+        model_wrap,
+        batch_sigmas,
+        batch_noise,
+        batch_latent,
+        cond,
+        indicies,
+        extra_args,
+        dataset_size,
+        bwd=True,
     ):
         xt = model_wrap.inner_model.model_sampling.noise_scaling(
             batch_sigmas, batch_noise, batch_latent, False
         )
@@ -1069,6 +1194,11 @@
                     default=False,
                     tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
                 ),
+                io.Boolean.Input(
+                    "dev_run",
+                    default=False,
+                    tooltip="Developer profiling mode. Forces batch_size=1, steps=1, records CUDA memory history during fwd_bwd, and exports a memory snapshot to the output folder.",
+                ),
             ],
             outputs=[
                 io.Custom("LORA_MODEL").Output(
@@ -1104,6 +1234,7 @@
         existing_lora,
         bucket_mode,
         bypass_mode,
+        dev_run,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
@@ -1124,6 +1255,13 @@
        existing_lora = existing_lora[0]
        bucket_mode = bucket_mode[0]
        bypass_mode = bypass_mode[0]
+        dev_run = dev_run[0]
+
+        # Dev run mode: force batch_size=1, steps=1 for memory profiling
+        if dev_run:
+            logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
+            batch_size = 1
+            steps = 1

        # Process latents based on mode
        if bucket_mode:
@@ -1228,6 +1366,7 @@
                training_dtype=dtype,
                bucket_latents=latents,
                use_grad_scaler=use_grad_scaler,
+                dev_run=dev_run,
            )
        else:
            train_sampler = TrainSampler(
@@ -1241,6 +1380,7 @@
                training_dtype=dtype,
                real_dataset=latents if multi_res else None,
                use_grad_scaler=use_grad_scaler,
+                dev_run=dev_run,
            )

        # Setup guider