Enable developer mode via command-line argument (--dev-mode)

This commit is contained in:
Kohaku-Blueleaf 2026-03-31 15:55:30 +08:00
parent 3a82bd15e7
commit 57aa65c16f
3 changed files with 65 additions and 26 deletions

View File

@ -237,6 +237,8 @@ database_default_path = os.path.abspath(
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--dev-mode", action="store_true", help="Enable developer mode. Activates trainer VRAM profiling (forces batch_size=1, steps=1) and verbose debug logging for weight adapter systems.")
if comfy.options.args_parsing:
args = parser.parse_args()
else:

View File

@ -22,13 +22,56 @@ import torch
import torch.nn as nn
import comfy.model_management
from comfy.cli_args import args
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
def _dev_log(msg: str):
    """Emit *msg* at INFO level, but only while --dev-mode is active."""
    if not args.dev_mode:
        return
    logging.info(msg)
# Type alias for adapters that support bypass mode
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
class _RecomputeH(torch.autograd.Function):
    """Recomputes adapter.h() during backward to avoid saving intermediates.
    Forward: runs h() under no_grad, saves only the input x.
    Backward: recomputes h() with enable_grad, backward through it.
    Adapter params receive gradients via direct .grad accumulation.
    """

    @staticmethod
    def forward(ctx, x, h_fn):
        # Activation-checkpoint style: keep only the input tensor; all of
        # h()'s intermediates are discarded and rebuilt in backward.
        ctx.save_for_backward(x)
        ctx.h_fn = h_fn
        # Remember where/how the forward ran so the backward recompute can
        # reproduce the same autocast device type and dtype.
        ctx.fwd_device = x.device
        ctx.fwd_dtype = x.dtype
        with torch.no_grad():
            # NOTE(review): second argument is None here — presumably h()
            # tolerates a missing base output in this path; confirm against
            # the adapter's h() signature.
            return h_fn(x, None)

    @staticmethod
    # once_differentiable: this backward is not itself differentiable
    # (it calls .backward() internally), so double-backward is disallowed.
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        h_fn = ctx.h_fn
        # Drop ctx's reference to the callable so the adapter/closure is not
        # kept alive by the autograd graph after this backward completes.
        ctx.h_fn = None
        # Re-run h() with grad enabled under the forward's autocast config,
        # then backprop grad_out through the recomputed graph. Adapter
        # parameters accumulate into their .grad as a side effect.
        with torch.enable_grad(), torch.autocast(
            ctx.fwd_device.type, dtype=ctx.fwd_dtype
        ):
            x_d = x.detach().requires_grad_(True)
            y = h_fn(x_d, None)
            y.backward(grad_out)
        grad_x = x_d.grad
        # Free the recomputed graph and temporaries promptly to keep peak
        # memory low — the whole point of this recompute scheme.
        del y, x_d, h_fn
        # Gradient w.r.t. x only; h_fn is a non-tensor argument (None slot).
        return grad_x, None
def get_module_type_info(module: nn.Module) -> dict:
"""
Determine module type and extract conv parameters from module class.
@ -171,13 +214,13 @@ class BypassForwardHook:
# Default bypass: g(f(x) + h(x, f(x)))
base_out = self.original_forward(x, *args, **kwargs)
h_out = self.adapter.h(x, base_out)
h_out = _RecomputeH.apply(x, self.adapter.h)
return self.adapter.g(base_out + h_out)
def inject(self):
"""Replace module forward with bypass version."""
if self.original_forward is not None:
logging.debug(
_dev_log(
f"[BypassHook] Already injected for {type(self.module).__name__}"
)
return # Already injected
@ -200,7 +243,7 @@ class BypassForwardHook:
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward
logging.debug(
_dev_log(
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
)
@ -217,7 +260,7 @@ class BypassForwardHook:
if isinstance(adapter, nn.Module):
# In training mode we don't touch dtype as trainer will handle it
adapter.to(device=device)
logging.debug(
_dev_log(
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
)
return
@ -246,17 +289,17 @@ class BypassForwardHook:
else:
adapter.weights = weights.to(device=device)
logging.debug(f"[BypassHook] Moved adapter weights to {device}")
_dev_log(f"[BypassHook] Moved adapter weights to {device}")
def eject(self):
"""Restore original module forward."""
if self.original_forward is None:
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
_dev_log(f"[BypassHook] Not injected for {type(self.module).__name__}")
return # Not injected
self.module.forward = self.original_forward
self.original_forward = None
logging.debug(
_dev_log(
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
)
@ -301,12 +344,12 @@ class BypassInjectionManager:
module_key = key
if module_key.endswith(".weight"):
module_key = module_key[:-7]
logging.debug(
_dev_log(
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
)
self.adapters[module_key] = (adapter, strength)
logging.debug(
_dev_log(
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
)
@ -324,7 +367,7 @@ class BypassInjectionManager:
module = module[int(part)]
else:
module = getattr(module, part)
logging.debug(
_dev_log(
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
)
return module
@ -347,13 +390,13 @@ class BypassInjectionManager:
"""
self.hooks.clear()
logging.debug(
_dev_log(
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
)
logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
_dev_log(f"[BypassManager] Model type: {type(model).__name__}")
for key, (adapter, strength) in self.adapters.items():
logging.debug(f"[BypassManager] Looking for module: {key}")
_dev_log(f"[BypassManager] Looking for module: {key}")
module = self._get_module_by_key(model, key)
if module is None:
@ -366,27 +409,27 @@ class BypassInjectionManager:
)
continue
logging.debug(
_dev_log(
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
)
hook = BypassForwardHook(module, adapter, multiplier=strength)
self.hooks.append(hook)
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
_dev_log(f"[BypassManager] Created {len(self.hooks)} hooks")
# Create single injection that manages all hooks
def inject_all(model_patcher):
logging.debug(
_dev_log(
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
)
for hook in self.hooks:
hook.inject()
logging.debug(
_dev_log(
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
)
def eject_all(model_patcher):
logging.debug(
_dev_log(
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
)
for hook in self.hooks:

View File

@ -1194,11 +1194,6 @@ class TrainLoraNode(io.ComfyNode):
default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
),
io.Boolean.Input(
"dev_run",
default=False,
tooltip="Developer profiling mode. Forces batch_size=1, steps=1, records CUDA memory history during fwd_bwd, and exports a memory snapshot to the output folder.",
),
],
outputs=[
io.Custom("LORA_MODEL").Output(
@ -1234,7 +1229,6 @@ class TrainLoraNode(io.ComfyNode):
existing_lora,
bucket_mode,
bypass_mode,
dev_run,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -1255,9 +1249,9 @@ class TrainLoraNode(io.ComfyNode):
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
dev_run = dev_run[0]
# Dev run mode: force batch_size=1, steps=1 for memory profiling
# Dev run mode (--dev-mode): force batch_size=1, steps=1 for memory profiling
dev_run = args.dev_mode
if dev_run:
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
batch_size = 1