diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13612175e..da103af06 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -237,6 +237,8 @@ database_default_path = os.path.abspath( parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") +parser.add_argument("--dev-mode", action="store_true", help="Enable developer mode. Activates trainer VRAM profiling (forces batch_size=1, steps=1) and verbose debug logging for weight adapter systems.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py index b9d5ec7d9..5c6cb736a 100644 --- a/comfy/weight_adapter/bypass.py +++ b/comfy/weight_adapter/bypass.py @@ -22,13 +22,56 @@ import torch import torch.nn as nn import comfy.model_management +from comfy.cli_args import args from .base import WeightAdapterBase, WeightAdapterTrainBase from comfy.patcher_extension import PatcherInjection + +def _dev_log(msg: str): + """Log debug message only when --dev-mode is enabled.""" + if args.dev_mode: + logging.info(msg) + # Type alias for adapters that support bypass mode BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] +class _RecomputeH(torch.autograd.Function): + """Recomputes adapter.h() during backward to avoid saving intermediates. + + Forward: runs h() under no_grad, saves only the input x. + Backward: recomputes h() with enable_grad, backward through it. + Adapter params receive gradients via direct .grad accumulation. + """ + + @staticmethod + def forward(ctx, x, h_fn): + ctx.save_for_backward(x) + ctx.h_fn = h_fn + ctx.fwd_device = x.device + ctx.fwd_dtype = x.dtype + with torch.no_grad(): + return h_fn(x, None) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_out): + x, = ctx.saved_tensors + h_fn = ctx.h_fn + ctx.h_fn = None + + with torch.enable_grad(), torch.autocast( + ctx.fwd_device.type, dtype=ctx.fwd_dtype + ): + x_d = x.detach().requires_grad_(True) + y = h_fn(x_d, None) + y.backward(grad_out) + + grad_x = x_d.grad + del y, x_d, h_fn + return grad_x, None + + def get_module_type_info(module: nn.Module) -> dict: """ Determine module type and extract conv parameters from module class. @@ -171,13 +214,13 @@ class BypassForwardHook: # Default bypass: g(f(x) + h(x, f(x))) base_out = self.original_forward(x, *args, **kwargs) - h_out = self.adapter.h(x, base_out) + h_out = _RecomputeH.apply(x, self.adapter.h) return self.adapter.g(base_out + h_out) def inject(self): """Replace module forward with bypass version.""" if self.original_forward is not None: - logging.debug( + _dev_log( f"[BypassHook] Already injected for {type(self.module).__name__}" ) return # Already injected @@ -200,7 +243,7 @@ class BypassForwardHook: self.original_forward = self.module.forward self.module.forward = self._bypass_forward - logging.debug( + _dev_log( f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" ) @@ -217,7 +260,7 @@ class BypassForwardHook: if isinstance(adapter, nn.Module): # In training mode we don't touch dtype as trainer will handle it adapter.to(device=device) - logging.debug( + _dev_log( f"[BypassHook] Moved training adapter (nn.Module) to {device}" ) return @@ -246,17 +289,17 @@ class BypassForwardHook: else: adapter.weights = weights.to(device=device) - logging.debug(f"[BypassHook] Moved adapter weights to {device}") + _dev_log(f"[BypassHook] Moved adapter weights to {device}") def eject(self): """Restore original module forward.""" if self.original_forward is None: - logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") + _dev_log(f"[BypassHook] Not injected for {type(self.module).__name__}") return # Not injected self.module.forward = self.original_forward self.original_forward = None - logging.debug( + _dev_log( f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" ) @@ -301,12 +344,12 @@ class BypassInjectionManager: module_key = key if module_key.endswith(".weight"): module_key = module_key[:-7] - logging.debug( + _dev_log( f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" ) self.adapters[module_key] = (adapter, strength) - logging.debug( + _dev_log( f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" ) @@ -324,7 +367,7 @@ class BypassInjectionManager: module = module[int(part)] else: module = getattr(module, part) - logging.debug( + _dev_log( f"[BypassManager] Found module for key {key}: {type(module).__name__}" ) return module @@ -347,13 +390,13 @@ class BypassInjectionManager: """ self.hooks.clear() - logging.debug( + _dev_log( f"[BypassManager] create_injections called with {len(self.adapters)} adapters" ) - logging.debug(f"[BypassManager] Model type: {type(model).__name__}") + _dev_log(f"[BypassManager] Model type: {type(model).__name__}") for key, (adapter, strength) in self.adapters.items(): - logging.debug(f"[BypassManager] Looking for module: {key}") + _dev_log(f"[BypassManager] Looking for module: {key}") module = self._get_module_by_key(model, key) if module is None: @@ -366,27 +409,27 @@ class BypassInjectionManager: ) continue - logging.debug( + _dev_log( f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" ) hook = BypassForwardHook(module, adapter, multiplier=strength) self.hooks.append(hook) - logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") + _dev_log(f"[BypassManager] Created {len(self.hooks)} hooks") # Create single injection that manages all hooks def inject_all(model_patcher): - logging.debug( + _dev_log( f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" ) for hook in self.hooks: hook.inject() - logging.debug( + _dev_log( f"[BypassManager] Injected hook for {type(hook.module).__name__}" ) def eject_all(model_patcher): - logging.debug( + _dev_log( f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" ) for hook in self.hooks: diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index e9b95e414..001141204 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1194,11 +1194,6 @@ class TrainLoraNode(io.ComfyNode): default=False, tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", ), - io.Boolean.Input( - "dev_run", - default=False, - tooltip="Developer profiling mode. Forces batch_size=1, steps=1, records CUDA memory history during fwd_bwd, and exports a memory snapshot to the output folder.", - ), ], outputs=[ io.Custom("LORA_MODEL").Output( @@ -1234,7 +1229,6 @@ class TrainLoraNode(io.ComfyNode): existing_lora, bucket_mode, bypass_mode, - dev_run, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -1255,9 +1249,9 @@ class TrainLoraNode(io.ComfyNode): existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] bypass_mode = bypass_mode[0] - dev_run = dev_run[0] - # Dev run mode: force batch_size=1, steps=1 for memory profiling + # Dev run mode (--dev-mode): force batch_size=1, steps=1 for memory profiling + dev_run = args.dev_mode if dev_run: logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling") batch_size = 1