Dev mode by args

This commit is contained in:
Kohaku-Blueleaf 2026-03-31 15:55:30 +08:00
parent 3a82bd15e7
commit 57aa65c16f
3 changed files with 65 additions and 26 deletions

View File

@ -237,6 +237,8 @@ database_default_path = os.path.abspath(
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--dev-mode", action="store_true", help="Enable developer mode. Activates trainer VRAM profiling (forces batch_size=1, steps=1) and verbose debug logging for weight adapter systems.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
else: else:

View File

@ -22,13 +22,56 @@ import torch
import torch.nn as nn import torch.nn as nn
import comfy.model_management import comfy.model_management
from comfy.cli_args import args
from .base import WeightAdapterBase, WeightAdapterTrainBase from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection from comfy.patcher_extension import PatcherInjection
def _dev_log(msg: str):
    """Emit *msg* at INFO level, but only while --dev-mode is active."""
    if not args.dev_mode:
        return
    logging.info(msg)
# Type alias for adapters that support bypass mode # Type alias for adapters that support bypass mode
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
class _RecomputeH(torch.autograd.Function):
"""Recomputes adapter.h() during backward to avoid saving intermediates.
Forward: runs h() under no_grad, saves only the input x.
Backward: recomputes h() with enable_grad, backward through it.
Adapter params receive gradients via direct .grad accumulation.
"""
@staticmethod
def forward(ctx, x, h_fn):
ctx.save_for_backward(x)
ctx.h_fn = h_fn
ctx.fwd_device = x.device
ctx.fwd_dtype = x.dtype
with torch.no_grad():
return h_fn(x, None)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_out):
x, = ctx.saved_tensors
h_fn = ctx.h_fn
ctx.h_fn = None
with torch.enable_grad(), torch.autocast(
ctx.fwd_device.type, dtype=ctx.fwd_dtype
):
x_d = x.detach().requires_grad_(True)
y = h_fn(x_d, None)
y.backward(grad_out)
grad_x = x_d.grad
del y, x_d, h_fn
return grad_x, None
def get_module_type_info(module: nn.Module) -> dict: def get_module_type_info(module: nn.Module) -> dict:
""" """
Determine module type and extract conv parameters from module class. Determine module type and extract conv parameters from module class.
@ -171,13 +214,13 @@ class BypassForwardHook:
# Default bypass: g(f(x) + h(x, f(x))) # Default bypass: g(f(x) + h(x, f(x)))
base_out = self.original_forward(x, *args, **kwargs) base_out = self.original_forward(x, *args, **kwargs)
h_out = self.adapter.h(x, base_out) h_out = _RecomputeH.apply(x, self.adapter.h)
return self.adapter.g(base_out + h_out) return self.adapter.g(base_out + h_out)
def inject(self): def inject(self):
"""Replace module forward with bypass version.""" """Replace module forward with bypass version."""
if self.original_forward is not None: if self.original_forward is not None:
logging.debug( _dev_log(
f"[BypassHook] Already injected for {type(self.module).__name__}" f"[BypassHook] Already injected for {type(self.module).__name__}"
) )
return # Already injected return # Already injected
@ -200,7 +243,7 @@ class BypassForwardHook:
self.original_forward = self.module.forward self.original_forward = self.module.forward
self.module.forward = self._bypass_forward self.module.forward = self._bypass_forward
logging.debug( _dev_log(
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
) )
@ -217,7 +260,7 @@ class BypassForwardHook:
if isinstance(adapter, nn.Module): if isinstance(adapter, nn.Module):
# In training mode we don't touch dtype as trainer will handle it # In training mode we don't touch dtype as trainer will handle it
adapter.to(device=device) adapter.to(device=device)
logging.debug( _dev_log(
f"[BypassHook] Moved training adapter (nn.Module) to {device}" f"[BypassHook] Moved training adapter (nn.Module) to {device}"
) )
return return
@ -246,17 +289,17 @@ class BypassForwardHook:
else: else:
adapter.weights = weights.to(device=device) adapter.weights = weights.to(device=device)
logging.debug(f"[BypassHook] Moved adapter weights to {device}") _dev_log(f"[BypassHook] Moved adapter weights to {device}")
def eject(self): def eject(self):
"""Restore original module forward.""" """Restore original module forward."""
if self.original_forward is None: if self.original_forward is None:
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") _dev_log(f"[BypassHook] Not injected for {type(self.module).__name__}")
return # Not injected return # Not injected
self.module.forward = self.original_forward self.module.forward = self.original_forward
self.original_forward = None self.original_forward = None
logging.debug( _dev_log(
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
) )
@ -301,12 +344,12 @@ class BypassInjectionManager:
module_key = key module_key = key
if module_key.endswith(".weight"): if module_key.endswith(".weight"):
module_key = module_key[:-7] module_key = module_key[:-7]
logging.debug( _dev_log(
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
) )
self.adapters[module_key] = (adapter, strength) self.adapters[module_key] = (adapter, strength)
logging.debug( _dev_log(
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
) )
@ -324,7 +367,7 @@ class BypassInjectionManager:
module = module[int(part)] module = module[int(part)]
else: else:
module = getattr(module, part) module = getattr(module, part)
logging.debug( _dev_log(
f"[BypassManager] Found module for key {key}: {type(module).__name__}" f"[BypassManager] Found module for key {key}: {type(module).__name__}"
) )
return module return module
@ -347,13 +390,13 @@ class BypassInjectionManager:
""" """
self.hooks.clear() self.hooks.clear()
logging.debug( _dev_log(
f"[BypassManager] create_injections called with {len(self.adapters)} adapters" f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
) )
logging.debug(f"[BypassManager] Model type: {type(model).__name__}") _dev_log(f"[BypassManager] Model type: {type(model).__name__}")
for key, (adapter, strength) in self.adapters.items(): for key, (adapter, strength) in self.adapters.items():
logging.debug(f"[BypassManager] Looking for module: {key}") _dev_log(f"[BypassManager] Looking for module: {key}")
module = self._get_module_by_key(model, key) module = self._get_module_by_key(model, key)
if module is None: if module is None:
@ -366,27 +409,27 @@ class BypassInjectionManager:
) )
continue continue
logging.debug( _dev_log(
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
) )
hook = BypassForwardHook(module, adapter, multiplier=strength) hook = BypassForwardHook(module, adapter, multiplier=strength)
self.hooks.append(hook) self.hooks.append(hook)
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") _dev_log(f"[BypassManager] Created {len(self.hooks)} hooks")
# Create single injection that manages all hooks # Create single injection that manages all hooks
def inject_all(model_patcher): def inject_all(model_patcher):
logging.debug( _dev_log(
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
) )
for hook in self.hooks: for hook in self.hooks:
hook.inject() hook.inject()
logging.debug( _dev_log(
f"[BypassManager] Injected hook for {type(hook.module).__name__}" f"[BypassManager] Injected hook for {type(hook.module).__name__}"
) )
def eject_all(model_patcher): def eject_all(model_patcher):
logging.debug( _dev_log(
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
) )
for hook in self.hooks: for hook in self.hooks:

View File

@ -1194,11 +1194,6 @@ class TrainLoraNode(io.ComfyNode):
default=False, default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
), ),
io.Boolean.Input(
"dev_run",
default=False,
tooltip="Developer profiling mode. Forces batch_size=1, steps=1, records CUDA memory history during fwd_bwd, and exports a memory snapshot to the output folder.",
),
], ],
outputs=[ outputs=[
io.Custom("LORA_MODEL").Output( io.Custom("LORA_MODEL").Output(
@ -1234,7 +1229,6 @@ class TrainLoraNode(io.ComfyNode):
existing_lora, existing_lora,
bucket_mode, bucket_mode,
bypass_mode, bypass_mode,
dev_run,
): ):
# Extract scalars from lists (due to is_input_list=True) # Extract scalars from lists (due to is_input_list=True)
model = model[0] model = model[0]
@ -1255,9 +1249,9 @@ class TrainLoraNode(io.ComfyNode):
existing_lora = existing_lora[0] existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0] bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0] bypass_mode = bypass_mode[0]
dev_run = dev_run[0]
# Dev run mode: force batch_size=1, steps=1 for memory profiling # Dev run mode (--dev-mode): force batch_size=1, steps=1 for memory profiling
dev_run = args.dev_mode
if dev_run: if dev_run:
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling") logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
batch_size = 1 batch_size = 1