mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-21 07:52:39 +08:00
Merge 4290dd82a3 into 3086026401
This commit is contained in:
commit
23c0c58ac2
@ -239,6 +239,8 @@ database_default_path = os.path.abspath(
|
|||||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||||
|
|
||||||
|
parser.add_argument("--dev-mode", action="store_true", help="Enable developer mode. Activates trainer VRAM profiling (forces batch_size=1, steps=1) and verbose debug logging for weight adapter systems.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -22,13 +22,56 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args
|
||||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||||
from comfy.patcher_extension import PatcherInjection
|
from comfy.patcher_extension import PatcherInjection
|
||||||
|
|
||||||
|
|
||||||
|
def _dev_log(msg: str):
|
||||||
|
"""Log debug message only when --dev-mode is enabled."""
|
||||||
|
if args.dev_mode:
|
||||||
|
logging.info(msg)
|
||||||
|
|
||||||
# Type alias for adapters that support bypass mode
|
# Type alias for adapters that support bypass mode
|
||||||
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
|
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
|
||||||
|
|
||||||
|
|
||||||
|
class _RecomputeH(torch.autograd.Function):
|
||||||
|
"""Recomputes adapter.h() during backward to avoid saving intermediates.
|
||||||
|
|
||||||
|
Forward: runs h() under no_grad, saves only the input x.
|
||||||
|
Backward: recomputes h() with enable_grad, backward through it.
|
||||||
|
Adapter params receive gradients via direct .grad accumulation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, h_fn):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
ctx.h_fn = h_fn
|
||||||
|
ctx.fwd_device = x.device
|
||||||
|
ctx.fwd_dtype = x.dtype
|
||||||
|
with torch.no_grad():
|
||||||
|
return h_fn(x, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
x, = ctx.saved_tensors
|
||||||
|
h_fn = ctx.h_fn
|
||||||
|
ctx.h_fn = None
|
||||||
|
|
||||||
|
with torch.enable_grad(), torch.autocast(
|
||||||
|
ctx.fwd_device.type, dtype=ctx.fwd_dtype
|
||||||
|
):
|
||||||
|
x_d = x.detach().requires_grad_(True)
|
||||||
|
y = h_fn(x_d, None)
|
||||||
|
y.backward(grad_out)
|
||||||
|
|
||||||
|
grad_x = x_d.grad
|
||||||
|
del y, x_d, h_fn
|
||||||
|
return grad_x, None
|
||||||
|
|
||||||
|
|
||||||
def get_module_type_info(module: nn.Module) -> dict:
|
def get_module_type_info(module: nn.Module) -> dict:
|
||||||
"""
|
"""
|
||||||
Determine module type and extract conv parameters from module class.
|
Determine module type and extract conv parameters from module class.
|
||||||
@ -171,13 +214,13 @@ class BypassForwardHook:
|
|||||||
|
|
||||||
# Default bypass: g(f(x) + h(x, f(x)))
|
# Default bypass: g(f(x) + h(x, f(x)))
|
||||||
base_out = self.original_forward(x, *args, **kwargs)
|
base_out = self.original_forward(x, *args, **kwargs)
|
||||||
h_out = self.adapter.h(x, base_out)
|
h_out = _RecomputeH.apply(x, self.adapter.h)
|
||||||
return self.adapter.g(base_out + h_out)
|
return self.adapter.g(base_out + h_out)
|
||||||
|
|
||||||
def inject(self):
|
def inject(self):
|
||||||
"""Replace module forward with bypass version."""
|
"""Replace module forward with bypass version."""
|
||||||
if self.original_forward is not None:
|
if self.original_forward is not None:
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassHook] Already injected for {type(self.module).__name__}"
|
f"[BypassHook] Already injected for {type(self.module).__name__}"
|
||||||
)
|
)
|
||||||
return # Already injected
|
return # Already injected
|
||||||
@ -200,7 +243,7 @@ class BypassForwardHook:
|
|||||||
|
|
||||||
self.original_forward = self.module.forward
|
self.original_forward = self.module.forward
|
||||||
self.module.forward = self._bypass_forward
|
self.module.forward = self._bypass_forward
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
|
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -217,7 +260,7 @@ class BypassForwardHook:
|
|||||||
if isinstance(adapter, nn.Module):
|
if isinstance(adapter, nn.Module):
|
||||||
# In training mode we don't touch dtype as trainer will handle it
|
# In training mode we don't touch dtype as trainer will handle it
|
||||||
adapter.to(device=device)
|
adapter.to(device=device)
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
|
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -246,17 +289,17 @@ class BypassForwardHook:
|
|||||||
else:
|
else:
|
||||||
adapter.weights = weights.to(device=device)
|
adapter.weights = weights.to(device=device)
|
||||||
|
|
||||||
logging.debug(f"[BypassHook] Moved adapter weights to {device}")
|
_dev_log(f"[BypassHook] Moved adapter weights to {device}")
|
||||||
|
|
||||||
def eject(self):
|
def eject(self):
|
||||||
"""Restore original module forward."""
|
"""Restore original module forward."""
|
||||||
if self.original_forward is None:
|
if self.original_forward is None:
|
||||||
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
|
_dev_log(f"[BypassHook] Not injected for {type(self.module).__name__}")
|
||||||
return # Not injected
|
return # Not injected
|
||||||
|
|
||||||
self.module.forward = self.original_forward
|
self.module.forward = self.original_forward
|
||||||
self.original_forward = None
|
self.original_forward = None
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
|
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -301,12 +344,12 @@ class BypassInjectionManager:
|
|||||||
module_key = key
|
module_key = key
|
||||||
if module_key.endswith(".weight"):
|
if module_key.endswith(".weight"):
|
||||||
module_key = module_key[:-7]
|
module_key = module_key[:-7]
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
|
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.adapters[module_key] = (adapter, strength)
|
self.adapters[module_key] = (adapter, strength)
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
|
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -324,7 +367,7 @@ class BypassInjectionManager:
|
|||||||
module = module[int(part)]
|
module = module[int(part)]
|
||||||
else:
|
else:
|
||||||
module = getattr(module, part)
|
module = getattr(module, part)
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
|
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
|
||||||
)
|
)
|
||||||
return module
|
return module
|
||||||
@ -347,13 +390,13 @@ class BypassInjectionManager:
|
|||||||
"""
|
"""
|
||||||
self.hooks.clear()
|
self.hooks.clear()
|
||||||
|
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
|
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
|
||||||
)
|
)
|
||||||
logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
|
_dev_log(f"[BypassManager] Model type: {type(model).__name__}")
|
||||||
|
|
||||||
for key, (adapter, strength) in self.adapters.items():
|
for key, (adapter, strength) in self.adapters.items():
|
||||||
logging.debug(f"[BypassManager] Looking for module: {key}")
|
_dev_log(f"[BypassManager] Looking for module: {key}")
|
||||||
module = self._get_module_by_key(model, key)
|
module = self._get_module_by_key(model, key)
|
||||||
|
|
||||||
if module is None:
|
if module is None:
|
||||||
@ -366,27 +409,27 @@ class BypassInjectionManager:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
|
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
|
||||||
)
|
)
|
||||||
hook = BypassForwardHook(module, adapter, multiplier=strength)
|
hook = BypassForwardHook(module, adapter, multiplier=strength)
|
||||||
self.hooks.append(hook)
|
self.hooks.append(hook)
|
||||||
|
|
||||||
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
|
_dev_log(f"[BypassManager] Created {len(self.hooks)} hooks")
|
||||||
|
|
||||||
# Create single injection that manages all hooks
|
# Create single injection that manages all hooks
|
||||||
def inject_all(model_patcher):
|
def inject_all(model_patcher):
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
|
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
|
||||||
)
|
)
|
||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
hook.inject()
|
hook.inject()
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
|
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def eject_all(model_patcher):
|
def eject_all(model_patcher):
|
||||||
logging.debug(
|
_dev_log(
|
||||||
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
|
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
|
||||||
)
|
)
|
||||||
for hook in self.hooks:
|
for hook in self.hooks:
|
||||||
|
|||||||
@ -140,6 +140,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
real_dataset=None,
|
real_dataset=None,
|
||||||
bucket_latents=None,
|
bucket_latents=None,
|
||||||
use_grad_scaler=False,
|
use_grad_scaler=False,
|
||||||
|
dev_run=False,
|
||||||
):
|
):
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -156,6 +157,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
)
|
)
|
||||||
# GradScaler for fp16 training
|
# GradScaler for fp16 training
|
||||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||||
|
self.dev_run = dev_run
|
||||||
# Precompute bucket offsets and weights for sampling
|
# Precompute bucket offsets and weights for sampling
|
||||||
if bucket_latents is not None:
|
if bucket_latents is not None:
|
||||||
self._init_bucket_data(bucket_latents)
|
self._init_bucket_data(bucket_latents)
|
||||||
@ -186,6 +188,129 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
extra_args,
|
extra_args,
|
||||||
dataset_size,
|
dataset_size,
|
||||||
bwd=True,
|
bwd=True,
|
||||||
|
):
|
||||||
|
if self.dev_run:
|
||||||
|
return self._fwd_bwd_dev(
|
||||||
|
model_wrap, batch_sigmas, batch_noise, batch_latent,
|
||||||
|
cond, indicies, extra_args, dataset_size, bwd,
|
||||||
|
)
|
||||||
|
return self._fwd_bwd_impl(
|
||||||
|
model_wrap, batch_sigmas, batch_noise, batch_latent,
|
||||||
|
cond, indicies, extra_args, dataset_size, bwd,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _vram_info(tag):
|
||||||
|
"""Log VRAM usage from both PyTorch allocator and actual GPU."""
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
allocated = torch.cuda.memory_allocated()
|
||||||
|
reserved = torch.cuda.memory_reserved()
|
||||||
|
peak = torch.cuda.max_memory_allocated()
|
||||||
|
free, total = torch.cuda.mem_get_info()
|
||||||
|
gpu_used = total - free
|
||||||
|
logging.info(
|
||||||
|
f"[DevRun] {tag}\n"
|
||||||
|
f" PyTorch allocated: {allocated / 1024**2:.1f} MB | "
|
||||||
|
f"reserved: {reserved / 1024**2:.1f} MB | "
|
||||||
|
f"peak allocated: {peak / 1024**2:.1f} MB\n"
|
||||||
|
f" GPU real usage: {gpu_used / 1024**2:.1f} MB / {total / 1024**2:.1f} MB | "
|
||||||
|
f"non-PyTorch: {(gpu_used - reserved) / 1024**2:.1f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fwd_bwd_dev(
|
||||||
|
self,
|
||||||
|
model_wrap,
|
||||||
|
batch_sigmas,
|
||||||
|
batch_noise,
|
||||||
|
batch_latent,
|
||||||
|
cond,
|
||||||
|
indicies,
|
||||||
|
extra_args,
|
||||||
|
dataset_size,
|
||||||
|
bwd,
|
||||||
|
):
|
||||||
|
"""Wraps fwd_bwd with CUDA memory profiling for dev_run mode."""
|
||||||
|
output_dir = folder_paths.get_output_directory()
|
||||||
|
snapshot_path = os.path.join(output_dir, "dev_run_memory_snapshot.pkl")
|
||||||
|
fwd_args = (model_wrap, batch_sigmas, batch_noise, batch_latent,
|
||||||
|
cond, indicies, extra_args, dataset_size)
|
||||||
|
|
||||||
|
# ── Phase 0: no_grad forward-only reference ──
|
||||||
|
logging.info("[DevRun] ═══ Phase 0: no_grad forward-only reference ═══")
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
self._vram_info("Before no_grad fwd")
|
||||||
|
with torch.no_grad():
|
||||||
|
self._fwd_bwd_impl(*fwd_args, bwd=False)
|
||||||
|
self._vram_info("After no_grad fwd (activations freed)")
|
||||||
|
logging.info("[DevRun] ═══ End Phase 0 ═══\n")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# ── Phase 1: forward pass (with grad, no backward) ──
|
||||||
|
logging.info("[DevRun] ═══ Phase 1: forward pass (with grad) ═══")
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
self._vram_info("Before fwd")
|
||||||
|
|
||||||
|
# Record memory history with Python-only stacks (works on Windows)
|
||||||
|
torch.cuda.memory._record_memory_history(max_entries=100000, stacks="python")
|
||||||
|
|
||||||
|
# Inline the forward part of _fwd_bwd_impl so we can measure before backward
|
||||||
|
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||||
|
batch_sigmas, batch_noise, batch_latent, False
|
||||||
|
)
|
||||||
|
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||||
|
torch.zeros_like(batch_sigmas),
|
||||||
|
torch.zeros_like(batch_noise),
|
||||||
|
batch_latent,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
model_wrap.conds["positive"] = [cond[i] for i in indicies]
|
||||||
|
batch_extra_args = make_batch_extra_option_dict(
|
||||||
|
extra_args, indicies, full_size=dataset_size
|
||||||
|
)
|
||||||
|
with torch.autocast(xt.device.type, dtype=self.training_dtype):
|
||||||
|
x0_pred = model_wrap(
|
||||||
|
xt.requires_grad_(True),
|
||||||
|
batch_sigmas.requires_grad_(True),
|
||||||
|
**batch_extra_args,
|
||||||
|
)
|
||||||
|
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||||
|
|
||||||
|
self._vram_info("After fwd (autograd graph alive)")
|
||||||
|
|
||||||
|
# ── Phase 2: backward pass ──
|
||||||
|
logging.info("[DevRun] ═══ Phase 2: backward pass ═══")
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
if bwd:
|
||||||
|
bwd_loss = loss / self.grad_acc
|
||||||
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(bwd_loss).backward()
|
||||||
|
else:
|
||||||
|
bwd_loss.backward()
|
||||||
|
|
||||||
|
self._vram_info("After bwd (grads computed)")
|
||||||
|
|
||||||
|
# ── Dump snapshot ──
|
||||||
|
torch.cuda.memory._dump_snapshot(snapshot_path)
|
||||||
|
torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
logging.info(
|
||||||
|
f"[DevRun] Memory snapshot saved to: {snapshot_path}\n"
|
||||||
|
f" → Visualize at https://pytorch.org/memory_viz (drag in the .pkl file)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _fwd_bwd_impl(
|
||||||
|
self,
|
||||||
|
model_wrap,
|
||||||
|
batch_sigmas,
|
||||||
|
batch_noise,
|
||||||
|
batch_latent,
|
||||||
|
cond,
|
||||||
|
indicies,
|
||||||
|
extra_args,
|
||||||
|
dataset_size,
|
||||||
|
bwd=True,
|
||||||
):
|
):
|
||||||
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||||
batch_sigmas, batch_noise, batch_latent, False
|
batch_sigmas, batch_noise, batch_latent, False
|
||||||
@ -1132,6 +1257,12 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
|
|
||||||
|
# Dev run mode (--dev-mode): force batch_size=1, steps=1 for memory profiling
|
||||||
|
dev_run = args.dev_mode
|
||||||
|
if dev_run:
|
||||||
|
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
|
||||||
|
batch_size = 1
|
||||||
|
steps = 1
|
||||||
comfy.model_management.training_fp8_bwd = quantized_backward
|
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||||
|
|
||||||
# Process latents based on mode
|
# Process latents based on mode
|
||||||
@ -1240,6 +1371,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
bucket_latents=latents,
|
bucket_latents=latents,
|
||||||
use_grad_scaler=use_grad_scaler,
|
use_grad_scaler=use_grad_scaler,
|
||||||
|
dev_run=dev_run,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = TrainSampler(
|
train_sampler = TrainSampler(
|
||||||
@ -1253,6 +1385,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
real_dataset=latents if multi_res else None,
|
real_dataset=latents if multi_res else None,
|
||||||
use_grad_scaler=use_grad_scaler,
|
use_grad_scaler=use_grad_scaler,
|
||||||
|
dev_run=dev_run,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user