Add dev run mode for vram profiling or debugging

This commit is contained in:
Kohaku-Blueleaf 2026-03-23 19:32:21 +08:00
parent 6265a239f3
commit 3a82bd15e7

View File

@ -140,6 +140,7 @@ class TrainSampler(comfy.samplers.Sampler):
real_dataset=None, real_dataset=None,
bucket_latents=None, bucket_latents=None,
use_grad_scaler=False, use_grad_scaler=False,
dev_run=False,
): ):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
@ -156,6 +157,7 @@ class TrainSampler(comfy.samplers.Sampler):
) )
# GradScaler for fp16 training # GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
self.dev_run = dev_run
# Precompute bucket offsets and weights for sampling # Precompute bucket offsets and weights for sampling
if bucket_latents is not None: if bucket_latents is not None:
self._init_bucket_data(bucket_latents) self._init_bucket_data(bucket_latents)
@ -186,6 +188,129 @@ class TrainSampler(comfy.samplers.Sampler):
extra_args, extra_args,
dataset_size, dataset_size,
bwd=True, bwd=True,
):
if self.dev_run:
return self._fwd_bwd_dev(
model_wrap, batch_sigmas, batch_noise, batch_latent,
cond, indicies, extra_args, dataset_size, bwd,
)
return self._fwd_bwd_impl(
model_wrap, batch_sigmas, batch_noise, batch_latent,
cond, indicies, extra_args, dataset_size, bwd,
)
@staticmethod
def _vram_info(tag):
"""Log VRAM usage from both PyTorch allocator and actual GPU."""
torch.cuda.synchronize()
allocated = torch.cuda.memory_allocated()
reserved = torch.cuda.memory_reserved()
peak = torch.cuda.max_memory_allocated()
free, total = torch.cuda.mem_get_info()
gpu_used = total - free
logging.info(
f"[DevRun] {tag}\n"
f" PyTorch allocated: {allocated / 1024**2:.1f} MB | "
f"reserved: {reserved / 1024**2:.1f} MB | "
f"peak allocated: {peak / 1024**2:.1f} MB\n"
f" GPU real usage: {gpu_used / 1024**2:.1f} MB / {total / 1024**2:.1f} MB | "
f"non-PyTorch: {(gpu_used - reserved) / 1024**2:.1f} MB"
)
def _fwd_bwd_dev(
self,
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd,
):
"""Wraps fwd_bwd with CUDA memory profiling for dev_run mode."""
output_dir = folder_paths.get_output_directory()
snapshot_path = os.path.join(output_dir, "dev_run_memory_snapshot.pkl")
fwd_args = (model_wrap, batch_sigmas, batch_noise, batch_latent,
cond, indicies, extra_args, dataset_size)
# ── Phase 0: no_grad forward-only reference ──
logging.info("[DevRun] ═══ Phase 0: no_grad forward-only reference ═══")
torch.cuda.reset_peak_memory_stats()
self._vram_info("Before no_grad fwd")
with torch.no_grad():
self._fwd_bwd_impl(*fwd_args, bwd=False)
self._vram_info("After no_grad fwd (activations freed)")
logging.info("[DevRun] ═══ End Phase 0 ═══\n")
torch.cuda.empty_cache()
# ── Phase 1: forward pass (with grad, no backward) ──
logging.info("[DevRun] ═══ Phase 1: forward pass (with grad) ═══")
torch.cuda.reset_peak_memory_stats()
self._vram_info("Before fwd")
# Record memory history with Python-only stacks (works on Windows)
torch.cuda.memory._record_memory_history(max_entries=100000, stacks="python")
# Inline the forward part of _fwd_bwd_impl so we can measure before backward
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas, batch_noise, batch_latent, False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False,
)
model_wrap.conds["positive"] = [cond[i] for i in indicies]
batch_extra_args = make_batch_extra_option_dict(
extra_args, indicies, full_size=dataset_size
)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(
xt.requires_grad_(True),
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred.float(), x0.float())
self._vram_info("After fwd (autograd graph alive)")
# ── Phase 2: backward pass ──
logging.info("[DevRun] ═══ Phase 2: backward pass ═══")
torch.cuda.reset_peak_memory_stats()
if bwd:
bwd_loss = loss / self.grad_acc
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
self._vram_info("After bwd (grads computed)")
# ── Dump snapshot ──
torch.cuda.memory._dump_snapshot(snapshot_path)
torch.cuda.memory._record_memory_history(enabled=None)
logging.info(
f"[DevRun] Memory snapshot saved to: {snapshot_path}\n"
f" → Visualize at https://pytorch.org/memory_viz (drag in the .pkl file)"
)
return loss
def _fwd_bwd_impl(
self,
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
): ):
xt = model_wrap.inner_model.model_sampling.noise_scaling( xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas, batch_noise, batch_latent, False batch_sigmas, batch_noise, batch_latent, False
@ -1069,6 +1194,11 @@ class TrainLoraNode(io.ComfyNode):
default=False, default=False,
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.", tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
), ),
io.Boolean.Input(
"dev_run",
default=False,
tooltip="Developer profiling mode. Forces batch_size=1, steps=1, records CUDA memory history during fwd_bwd, and exports a memory snapshot to the output folder.",
),
], ],
outputs=[ outputs=[
io.Custom("LORA_MODEL").Output( io.Custom("LORA_MODEL").Output(
@ -1104,6 +1234,7 @@ class TrainLoraNode(io.ComfyNode):
existing_lora, existing_lora,
bucket_mode, bucket_mode,
bypass_mode, bypass_mode,
dev_run,
): ):
# Extract scalars from lists (due to is_input_list=True) # Extract scalars from lists (due to is_input_list=True)
model = model[0] model = model[0]
@ -1124,6 +1255,13 @@ class TrainLoraNode(io.ComfyNode):
existing_lora = existing_lora[0] existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0] bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0] bypass_mode = bypass_mode[0]
dev_run = dev_run[0]
# Dev run mode: force batch_size=1, steps=1 for memory profiling
if dev_run:
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
batch_size = 1
steps = 1
# Process latents based on mode # Process latents based on mode
if bucket_mode: if bucket_mode:
@ -1228,6 +1366,7 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype, training_dtype=dtype,
bucket_latents=latents, bucket_latents=latents,
use_grad_scaler=use_grad_scaler, use_grad_scaler=use_grad_scaler,
dev_run=dev_run,
) )
else: else:
train_sampler = TrainSampler( train_sampler = TrainSampler(
@ -1241,6 +1380,7 @@ class TrainLoraNode(io.ComfyNode):
training_dtype=dtype, training_dtype=dtype,
real_dataset=latents if multi_res else None, real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler, use_grad_scaler=use_grad_scaler,
dev_run=dev_run,
) )
# Setup guider # Setup guider