Add dev run mode for vram profiling or debugging

This commit is contained in:
Kohaku-Blueleaf 2026-03-23 19:32:21 +08:00
parent 6265a239f3
commit 3a82bd15e7

View File

@ -140,6 +140,7 @@ class TrainSampler(comfy.samplers.Sampler):
real_dataset=None,
bucket_latents=None,
use_grad_scaler=False,
dev_run=False,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -156,6 +157,7 @@ class TrainSampler(comfy.samplers.Sampler):
)
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
self.dev_run = dev_run
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
@ -186,6 +188,129 @@ class TrainSampler(comfy.samplers.Sampler):
extra_args,
dataset_size,
bwd=True,
):
if self.dev_run:
return self._fwd_bwd_dev(
model_wrap, batch_sigmas, batch_noise, batch_latent,
cond, indicies, extra_args, dataset_size, bwd,
)
return self._fwd_bwd_impl(
model_wrap, batch_sigmas, batch_noise, batch_latent,
cond, indicies, extra_args, dataset_size, bwd,
)
@staticmethod
def _vram_info(tag):
"""Log VRAM usage from both PyTorch allocator and actual GPU."""
torch.cuda.synchronize()
allocated = torch.cuda.memory_allocated()
reserved = torch.cuda.memory_reserved()
peak = torch.cuda.max_memory_allocated()
free, total = torch.cuda.mem_get_info()
gpu_used = total - free
logging.info(
f"[DevRun] {tag}\n"
f" PyTorch allocated: {allocated / 1024**2:.1f} MB | "
f"reserved: {reserved / 1024**2:.1f} MB | "
f"peak allocated: {peak / 1024**2:.1f} MB\n"
f" GPU real usage: {gpu_used / 1024**2:.1f} MB / {total / 1024**2:.1f} MB | "
f"non-PyTorch: {(gpu_used - reserved) / 1024**2:.1f} MB"
)
def _fwd_bwd_dev(
self,
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd,
):
"""Wraps fwd_bwd with CUDA memory profiling for dev_run mode."""
output_dir = folder_paths.get_output_directory()
snapshot_path = os.path.join(output_dir, "dev_run_memory_snapshot.pkl")
fwd_args = (model_wrap, batch_sigmas, batch_noise, batch_latent,
cond, indicies, extra_args, dataset_size)
# ── Phase 0: no_grad forward-only reference ──
logging.info("[DevRun] ═══ Phase 0: no_grad forward-only reference ═══")
torch.cuda.reset_peak_memory_stats()
self._vram_info("Before no_grad fwd")
with torch.no_grad():
self._fwd_bwd_impl(*fwd_args, bwd=False)
self._vram_info("After no_grad fwd (activations freed)")
logging.info("[DevRun] ═══ End Phase 0 ═══\n")
torch.cuda.empty_cache()
# ── Phase 1: forward pass (with grad, no backward) ──
logging.info("[DevRun] ═══ Phase 1: forward pass (with grad) ═══")
torch.cuda.reset_peak_memory_stats()
self._vram_info("Before fwd")
# Record memory history with Python-only stacks (works on Windows)
torch.cuda.memory._record_memory_history(max_entries=100000, stacks="python")
# Inline the forward part of _fwd_bwd_impl so we can measure before backward
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas, batch_noise, batch_latent, False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False,
)
model_wrap.conds["positive"] = [cond[i] for i in indicies]
batch_extra_args = make_batch_extra_option_dict(
extra_args, indicies, full_size=dataset_size
)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(
xt.requires_grad_(True),
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred.float(), x0.float())
self._vram_info("After fwd (autograd graph alive)")
# ── Phase 2: backward pass ──
logging.info("[DevRun] ═══ Phase 2: backward pass ═══")
torch.cuda.reset_peak_memory_stats()
if bwd:
bwd_loss = loss / self.grad_acc
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
self._vram_info("After bwd (grads computed)")
# ── Dump snapshot ──
torch.cuda.memory._dump_snapshot(snapshot_path)
torch.cuda.memory._record_memory_history(enabled=None)
logging.info(
f"[DevRun] Memory snapshot saved to: {snapshot_path}\n"
f" → Visualize at https://pytorch.org/memory_viz (drag in the .pkl file)"
)
return loss
def _fwd_bwd_impl(
self,
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
):
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas, batch_noise, batch_latent, False
@ -1069,6 +1194,11 @@ class TrainLoraNode(io.ComfyNode):
default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
),
io.Boolean.Input(
"dev_run",
default=False,
tooltip="Developer profiling mode. Forces batch_size=1, steps=1, records CUDA memory history during fwd_bwd, and exports a memory snapshot to the output folder.",
),
],
outputs=[
io.Custom("LORA_MODEL").Output(
@ -1104,6 +1234,7 @@ class TrainLoraNode(io.ComfyNode):
existing_lora,
bucket_mode,
bypass_mode,
dev_run,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
@ -1124,6 +1255,13 @@ class TrainLoraNode(io.ComfyNode):
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
dev_run = dev_run[0]
# Dev run mode: force batch_size=1, steps=1 for memory profiling
if dev_run:
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
batch_size = 1
steps = 1
# Process latents based on mode
if bucket_mode:
@ -1228,6 +1366,7 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype,
bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
dev_run=dev_run,
)
else:
train_sampler = TrainSampler(
@ -1241,6 +1380,7 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype,
real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
dev_run=dev_run,
)
# Setup guider