diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dbaadf723..87a9d7e9b 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -239,6 +239,8 @@ database_default_path = os.path.abspath(
 parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
 parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
 
+parser.add_argument("--dev-mode", action="store_true", help="Enable developer mode. Activates trainer VRAM profiling (forces batch_size=1, steps=1) and verbose debug logging for weight adapter systems.")
+
 if comfy.options.args_parsing:
     args = parser.parse_args()
 else:
diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py
index b9d5ec7d9..5c6cb736a 100644
--- a/comfy/weight_adapter/bypass.py
+++ b/comfy/weight_adapter/bypass.py
@@ -22,13 +22,85 @@ import torch
 import torch
 import torch.nn as nn
 import comfy.model_management
+from comfy.cli_args import args
 from .base import WeightAdapterBase, WeightAdapterTrainBase
 from comfy.patcher_extension import PatcherInjection
 
+
+def _dev_log(msg: str):
+    """Log a weight-adapter debug message.
+
+    Emitted at INFO level when --dev-mode is enabled, otherwise at DEBUG level,
+    which is where these messages were logged before --dev-mode existed.
+    """
+    level = logging.INFO if args.dev_mode else logging.DEBUG
+    logging.log(level, msg)
+
+
+def _autocast_state(device_type: str):
+    """Return (enabled, dtype) of the autocast context active for device_type."""
+    try:
+        return (
+            torch.is_autocast_enabled(device_type),
+            torch.get_autocast_dtype(device_type),
+        )
+    except (AttributeError, TypeError):
+        # Older torch releases have no per-device autocast query API.
+        if device_type == "cpu":
+            return torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype()
+        return torch.is_autocast_enabled(), torch.get_autocast_gpu_dtype()
+
 # Type alias for adapters that support bypass mode
 BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
 
 
+class _RecomputeH(torch.autograd.Function):
+    """Recompute adapter.h() during backward to avoid saving intermediates.
+
+    Forward: runs h() under no_grad, saves only the input x and the autocast
+    state that was active.
+    Backward: recomputes h() with enable_grad under that same autocast state
+    and backpropagates through it. Adapter params receive gradients via direct
+    .grad accumulation; only the gradient w.r.t. x is returned to autograd.
+
+    NOTE(review): h() is called with base_out=None, so this path is only valid
+    for adapters whose h() ignores base_out -- confirm for every adapter type.
+    Callers must only use it when x.requires_grad is True; otherwise autograd
+    never runs backward() and the adapter params get no gradient.
+    """
+
+    @staticmethod
+    def forward(ctx, x, h_fn):
+        ctx.save_for_backward(x)
+        ctx.h_fn = h_fn
+        ctx.fwd_device = x.device
+        # x.dtype need not match the autocast dtype, so record the real autocast
+        # state instead of inferring it from the input tensor.
+        ctx.autocast_enabled, ctx.autocast_dtype = _autocast_state(x.device.type)
+        with torch.no_grad():
+            return h_fn(x, None)
+
+    @staticmethod
+    @torch.autograd.function.once_differentiable
+    def backward(ctx, grad_out):
+        x, = ctx.saved_tensors
+        h_fn = ctx.h_fn
+        ctx.h_fn = None
+
+        with torch.enable_grad(), torch.autocast(
+            ctx.fwd_device.type,
+            dtype=ctx.autocast_dtype,
+            enabled=ctx.autocast_enabled,
+        ):
+            x_d = x.detach().requires_grad_(True)
+            y = h_fn(x_d, None)
+            y.backward(grad_out)
+
+        grad_x = x_d.grad
+        del y, x_d, h_fn
+        return grad_x, None
+
+
 def get_module_type_info(module: nn.Module) -> dict:
     """
     Determine module type and extract conv parameters from module class.
@@ -171,13 +243,19 @@ class BypassForwardHook:
 
         # Default bypass: g(f(x) + h(x, f(x)))
         base_out = self.original_forward(x, *args, **kwargs)
-        h_out = self.adapter.h(x, base_out)
+        if torch.is_grad_enabled() and x.requires_grad:
+            # Training path: recompute h() in backward instead of keeping its
+            # intermediates alive.
+            h_out = _RecomputeH.apply(x, self.adapter.h)
+        else:
+            # A custom Function records no graph here, so call h() directly.
+            h_out = self.adapter.h(x, base_out)
         return self.adapter.g(base_out + h_out)
 
     def inject(self):
         """Replace module forward with bypass version."""
         if self.original_forward is not None:
-            logging.debug(
+            _dev_log(
                 f"[BypassHook] Already injected for {type(self.module).__name__}"
             )
             return  # Already injected
@@ -200,7 +278,7 @@ class BypassForwardHook:
 
         self.original_forward = self.module.forward
         self.module.forward = self._bypass_forward
-        logging.debug(
+        _dev_log(
             f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
         )
 
@@ -217,7 +295,7 @@ class BypassForwardHook:
         if isinstance(adapter, nn.Module):
             # In training mode we don't touch dtype as trainer will handle it
             adapter.to(device=device)
-            logging.debug(
+            _dev_log(
                 f"[BypassHook] Moved training adapter (nn.Module) to {device}"
             )
             return
@@ -246,17 +324,17 @@ class BypassForwardHook:
         else:
             adapter.weights = weights.to(device=device)
 
-        logging.debug(f"[BypassHook] Moved adapter weights to {device}")
+        _dev_log(f"[BypassHook] Moved adapter weights to {device}")
 
     def eject(self):
         """Restore original module forward."""
         if self.original_forward is None:
-            logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
+            _dev_log(f"[BypassHook] Not injected for {type(self.module).__name__}")
             return  # Not injected
 
         self.module.forward = self.original_forward
         self.original_forward = None
-        logging.debug(
+        _dev_log(
             f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
         )
 
@@ -301,12 +379,12 @@ class BypassInjectionManager:
         module_key = key
         if module_key.endswith(".weight"):
             module_key = module_key[:-7]
-            logging.debug(
+            _dev_log(
                 f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
             )
 
         self.adapters[module_key] = (adapter, strength)
-        logging.debug(
+        _dev_log(
             f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
         )
 
@@ -324,7 +402,7 @@ class BypassInjectionManager:
                     module = module[int(part)]
                 else:
                     module = getattr(module, part)
-            logging.debug(
+            _dev_log(
                 f"[BypassManager] Found module for key {key}: {type(module).__name__}"
             )
             return module
@@ -347,13 +425,13 @@ class BypassInjectionManager:
         """
         self.hooks.clear()
 
-        logging.debug(
+        _dev_log(
             f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
         )
-        logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
+        _dev_log(f"[BypassManager] Model type: {type(model).__name__}")
 
         for key, (adapter, strength) in self.adapters.items():
-            logging.debug(f"[BypassManager] Looking for module: {key}")
+            _dev_log(f"[BypassManager] Looking for module: {key}")
             module = self._get_module_by_key(model, key)
 
             if module is None:
@@ -366,27 +444,27 @@ class BypassInjectionManager:
                 )
                 continue
 
-            logging.debug(
+            _dev_log(
                 f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
             )
             hook = BypassForwardHook(module, adapter, multiplier=strength)
             self.hooks.append(hook)
 
-        logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
+        _dev_log(f"[BypassManager] Created {len(self.hooks)} hooks")
 
         # Create single injection that manages all hooks
         def inject_all(model_patcher):
-            logging.debug(
+            _dev_log(
                 f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
             )
             for hook in self.hooks:
                 hook.inject()
-                logging.debug(
+                _dev_log(
                     f"[BypassManager] Injected hook for {type(hook.module).__name__}"
                 )
 
         def eject_all(model_patcher):
-            logging.debug(
+            _dev_log(
                 f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
             )
             for hook in self.hooks:
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 0616dfc2d..a0f941911 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -140,6 +140,7 @@ class TrainSampler(comfy.samplers.Sampler):
         real_dataset=None,
         bucket_latents=None,
         use_grad_scaler=False,
+        dev_run=False,
     ):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
@@ -156,6 +157,7 @@ class TrainSampler(comfy.samplers.Sampler):
         )
         # GradScaler for fp16 training
         self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
+        self.dev_run = dev_run
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
             self._init_bucket_data(bucket_latents)
@@ -186,6 +188,146 @@ class TrainSampler(comfy.samplers.Sampler):
         extra_args,
         dataset_size,
         bwd=True,
+    ):
+        if self.dev_run:
+            return self._fwd_bwd_dev(
+                model_wrap, batch_sigmas, batch_noise, batch_latent,
+                cond, indicies, extra_args, dataset_size, bwd,
+            )
+        return self._fwd_bwd_impl(
+            model_wrap, batch_sigmas, batch_noise, batch_latent,
+            cond, indicies, extra_args, dataset_size, bwd,
+        )
+
+    @staticmethod
+    def _vram_info(tag):
+        """Log VRAM usage from both PyTorch allocator and actual GPU."""
+        if not torch.cuda.is_available():
+            logging.info(f"[DevRun] {tag}\n VRAM stats unavailable (no CUDA device)")
+            return
+        torch.cuda.synchronize()
+        allocated = torch.cuda.memory_allocated()
+        reserved = torch.cuda.memory_reserved()
+        peak = torch.cuda.max_memory_allocated()
+        free, total = torch.cuda.mem_get_info()
+        gpu_used = total - free
+        logging.info(
+            f"[DevRun] {tag}\n"
+            f" PyTorch allocated: {allocated / 1024**2:.1f} MB | "
+            f"reserved: {reserved / 1024**2:.1f} MB | "
+            f"peak allocated: {peak / 1024**2:.1f} MB\n"
+            f" GPU real usage: {gpu_used / 1024**2:.1f} MB / {total / 1024**2:.1f} MB | "
+            f"non-PyTorch: {(gpu_used - reserved) / 1024**2:.1f} MB"
+        )
+
+    def _fwd_bwd_dev(
+        self,
+        model_wrap,
+        batch_sigmas,
+        batch_noise,
+        batch_latent,
+        cond,
+        indicies,
+        extra_args,
+        dataset_size,
+        bwd,
+    ):
+        """Run fwd_bwd with CUDA memory profiling for dev_run mode.
+
+        Logs VRAM usage around a no_grad reference forward, the real forward
+        and the backward pass, and dumps a CUDA memory snapshot into the output
+        directory. Returns the loss of the real forward pass. Without a CUDA
+        device the step runs unprofiled through _fwd_bwd_impl.
+        """
+        fwd_args = (model_wrap, batch_sigmas, batch_noise, batch_latent,
+                    cond, indicies, extra_args, dataset_size)
+        if not torch.cuda.is_available():
+            logging.warning("[DevRun] CUDA not available; running without memory profiling")
+            return self._fwd_bwd_impl(*fwd_args, bwd=bwd)
+
+        output_dir = folder_paths.get_output_directory()
+        snapshot_path = os.path.join(output_dir, "dev_run_memory_snapshot.pkl")
+
+        # ── Phase 0: no_grad forward-only reference ──
+        logging.info("[DevRun] ═══ Phase 0: no_grad forward-only reference ═══")
+        torch.cuda.reset_peak_memory_stats()
+        self._vram_info("Before no_grad fwd")
+        with torch.no_grad():
+            self._fwd_bwd_impl(*fwd_args, bwd=False)
+        self._vram_info("After no_grad fwd (activations freed)")
+        logging.info("[DevRun] ═══ End Phase 0 ═══\n")
+        torch.cuda.empty_cache()
+
+        # ── Phase 1: forward pass (with grad, no backward) ──
+        logging.info("[DevRun] ═══ Phase 1: forward pass (with grad) ═══")
+        torch.cuda.reset_peak_memory_stats()
+        self._vram_info("Before fwd")
+
+        # Record memory history with Python-only stacks (works on Windows)
+        torch.cuda.memory._record_memory_history(max_entries=100000, stacks="python")
+        try:
+            # Inline the forward part of _fwd_bwd_impl so we can measure before backward
+            xt = model_wrap.inner_model.model_sampling.noise_scaling(
+                batch_sigmas, batch_noise, batch_latent, False
+            )
+            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+                torch.zeros_like(batch_sigmas),
+                torch.zeros_like(batch_noise),
+                batch_latent,
+                False,
+            )
+            model_wrap.conds["positive"] = [cond[i] for i in indicies]
+            batch_extra_args = make_batch_extra_option_dict(
+                extra_args, indicies, full_size=dataset_size
+            )
+            with torch.autocast(xt.device.type, dtype=self.training_dtype):
+                x0_pred = model_wrap(
+                    xt.requires_grad_(True),
+                    batch_sigmas.requires_grad_(True),
+                    **batch_extra_args,
+                )
+                loss = self.loss_fn(x0_pred.float(), x0.float())
+
+            self._vram_info("After fwd (autograd graph alive)")
+
+            # ── Phase 2: backward pass ──
+            logging.info("[DevRun] ═══ Phase 2: backward pass ═══")
+            torch.cuda.reset_peak_memory_stats()
+
+            if bwd:
+                bwd_loss = loss / self.grad_acc
+                if self.grad_scaler is not None:
+                    self.grad_scaler.scale(bwd_loss).backward()
+                else:
+                    bwd_loss.backward()
+                self._vram_info("After bwd (grads computed)")
+            else:
+                self._vram_info("Backward skipped (bwd=False)")
+        finally:
+            # Dump the snapshot and stop recording even when the step raises:
+            # an out-of-memory failure is when the snapshot is most useful.
+            try:
+                torch.cuda.memory._dump_snapshot(snapshot_path)
+                logging.info(
+                    f"[DevRun] Memory snapshot saved to: {snapshot_path}\n"
+                    f" → Visualize at https://pytorch.org/memory_viz (drag in the .pkl file)"
+                )
+            finally:
+                torch.cuda.memory._record_memory_history(enabled=None)
+
+        return loss
+
+    def _fwd_bwd_impl(
+        self,
+        model_wrap,
+        batch_sigmas,
+        batch_noise,
+        batch_latent,
+        cond,
+        indicies,
+        extra_args,
+        dataset_size,
+        bwd=True,
     ):
         xt = model_wrap.inner_model.model_sampling.noise_scaling(
             batch_sigmas, batch_noise, batch_latent, False
@@ -1132,6 +1274,12 @@ class TrainLoraNode(io.ComfyNode):
         bucket_mode = bucket_mode[0]
         bypass_mode = bypass_mode[0]
 
+        # Dev run mode (--dev-mode): force batch_size=1, steps=1 for memory profiling
+        dev_run = args.dev_mode
+        if dev_run:
+            logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
+            batch_size = 1
+            steps = 1
         comfy.model_management.training_fp8_bwd = quantized_backward
 
         # Process latents based on mode
@@ -1240,6 +1388,7 @@ class TrainLoraNode(io.ComfyNode):
                 training_dtype=dtype,
                 bucket_latents=latents,
                 use_grad_scaler=use_grad_scaler,
+                dev_run=dev_run,
             )
         else:
             train_sampler = TrainSampler(
@@ -1253,6 +1402,7 @@ class TrainLoraNode(io.ComfyNode):
                 training_dtype=dtype,
                 real_dataset=latents if multi_res else None,
                 use_grad_scaler=use_grad_scaler,
+                dev_run=dev_run,
             )
 
         # Setup guider