offloading implementation in training node

Kohaku-Blueleaf 2026-01-31 16:59:45 +08:00
parent ec61c02bf6
commit 20bd2c0236


@@ -4,6 +4,7 @@ import os
 import numpy as np
 import safetensors
 import torch
+import torch.nn as nn
 import torch.utils.checkpoint
 from tqdm.auto import trange
 from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
     """
     CFGGuider with modifications for training specific logic
     """
+
+    def __init__(self, *args, offloading=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.offloading = offloading
+
     def outer_sample(
         self,
         noise,
@@ -45,7 +51,8 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                 noise.shape,
                 self.conds,
                 self.model_options,
-                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
+                force_full_load=False,
+                force_offload=self.offloading,
             )
         )
         device = self.model_patcher.load_device
@@ -404,16 +411,97 @@ def find_all_highest_child_module_with_forward(
     return result


-def patch(m):
+def find_modules_at_depth(
+    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+    """
+    Find modules at a specific depth level for gradient checkpointing.
+
+    Args:
+        model: The model to search
+        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+        result: Accumulator for results
+        current_depth: Current recursion depth
+        name: Current module name for logging
+
+    Returns:
+        List of modules at the target depth
+    """
+    if result is None:
+        result = []
+        name = name or "root"
+
+    # Skip container modules (they don't have meaningful forward)
+    is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+    has_forward = hasattr(model, "forward") and not is_container
+
+    if has_forward:
+        current_depth += 1
+        if current_depth == depth:
+            result.append(model)
+            logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+            return result
+
+    # Recurse into children
+    for next_name, child in model.named_children():
+        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+
+    return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+    """
+    Gradient checkpointing that works with weight offloading.
+
+    Forward: no_grad -> compute -> weights can be freed
+    Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+    For single-input, single-output modules (Linear, Conv*).
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, forward_fn):
+        ctx.save_for_backward(x)
+        ctx.forward_fn = forward_fn
+        with torch.no_grad():
+            return forward_fn(x)
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        x, = ctx.saved_tensors
+        forward_fn = ctx.forward_fn
+        # Clear context early
+        ctx.forward_fn = None
+
+        with torch.enable_grad():
+            x_detached = x.detach().requires_grad_(True)
+            y = forward_fn(x_detached)
+            y.backward(grad_out)
+            grad_x = x_detached.grad
+
+        # Explicit cleanup
+        del y, x_detached, forward_fn
+        return grad_x, None
+
+
+def patch(m, offloading=False):
     if not hasattr(m, "forward"):
         return
     org_forward = m.forward

-    def fwd(args, kwargs):
-        return org_forward(*args, **kwargs)
+    # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
+    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        def checkpointing_fwd(x):
+            return OffloadCheckpointFunction.apply(x, org_forward)
+    # Branch 2: Others -> standard checkpoint
+    else:
+        def fwd(args, kwargs):
+            return org_forward(*args, **kwargs)

-    def checkpointing_fwd(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+        def checkpointing_fwd(*args, **kwargs):
+            return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

     m.org_forward = org_forward
     m.forward = checkpointing_fwd
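
The custom autograd function above is the core of the offload-compatible checkpointing: the forward pass runs under no_grad, so no intermediate activations are kept and the layer's weights can be released afterwards, while the backward pass recomputes the layer before propagating gradients. A minimal standalone sketch of that behavior, assuming OffloadCheckpointFunction as defined in this diff is in scope; the toy nn.Linear and the comparison against plain autograd are illustrative only, not part of the commit:

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# Checkpointed path: forward runs under no_grad, the layer is recomputed in backward.
out = OffloadCheckpointFunction.apply(x, layer.forward)
out.sum().backward()
grad_x_ckpt = x.grad.clone()
grad_w_ckpt = layer.weight.grad.clone()

# Reference path: ordinary autograd on the same layer and input.
x.grad = None
layer.zero_grad()
layer(x).sum().backward()

# Gradients from the recompute-in-backward path match plain autograd.
assert torch.allclose(grad_x_ckpt, x.grad)
assert torch.allclose(grad_w_ckpt, layer.weight.grad)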
@@ -936,6 +1024,18 @@ class TrainLoraNode(io.ComfyNode):
                     default=True,
                     tooltip="Use gradient checkpointing for training.",
                 ),
+                io.Int.Input(
+                    "checkpoint_depth",
+                    default=1,
+                    min=1,
+                    max=5,
+                    tooltip="Depth level for gradient checkpointing.",
+                ),
+                io.Boolean.Input(
+                    "offloading",
+                    default=False,
+                    tooltip="Offload model weights during training.",
+                ),
                 io.Combo.Input(
                     "existing_lora",
                     options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1082,8 @@
         lora_dtype,
         algorithm,
         gradient_checkpointing,
+        checkpoint_depth,
+        offloading,
         existing_lora,
         bucket_mode,
         bypass_mode,
@@ -1000,6 +1102,8 @@
         lora_dtype = lora_dtype[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
+        checkpoint_depth = checkpoint_depth[0]
+        offloading = offloading[0]
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
         bypass_mode = bypass_mode[0]
@@ -1054,16 +1157,18 @@

         # Setup gradient checkpointing
         if gradient_checkpointing:
-            for m in find_all_highest_child_module_with_forward(
-                mp.model.diffusion_model
-            ):
-                patch(m)
+            modules_to_patch = find_modules_at_depth(
+                mp.model.diffusion_model, depth=checkpoint_depth
+            )
+            logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+            for m in modules_to_patch:
+                patch(m, offloading=offloading)
         torch.cuda.empty_cache()

         # With force_full_load=False we should be able to have offloading
         # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=True
+            [mp], memory_required=1e20, force_full_load=False
         )
         torch.cuda.empty_cache()
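
How checkpoint_depth translates into patched modules can be seen on a toy module tree. A minimal sketch, assuming find_modules_at_depth and patch from the hunks above are importable; ToyBlock and ToyModel are made-up stand-ins for the diffusion model's blocks. Note that with the traversal as written, the root module itself counts as depth 1 and containers such as nn.ModuleList do not add a level:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(), ToyBlock()])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = ToyModel()

# depth=1 -> [ToyModel]            (the root is the first module with a forward)
# depth=2 -> [ToyBlock, ToyBlock]  (the nn.ModuleList container is skipped)
# depth=3 -> [Linear, Linear]      (the proj layers inside each block)
for depth in (1, 2, 3):
    found = find_modules_at_depth(model, depth=depth)
    print(depth, [m.__class__.__name__ for m in found])

# Mirror the node's wiring: patch the Linear layers (depth 3) with offloading
# enabled so they go through OffloadCheckpointFunction, then run one step.
for m in find_modules_at_depth(model, depth=3):
    patch(m, offloading=True)

x = torch.randn(4, 16, requires_grad=True)
model(x).sum().backward()
print(model.blocks[0].proj.weight.grad.shape)  # weight grads exist after recompute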
@@ -1100,7 +1205,7 @@
         )

         # Setup guider
-        guider = TrainGuider(mp)
+        guider = TrainGuider(mp, offloading=offloading)
         guider.set_conds(positive)

         # Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1218,7 @@

         # Run training loop
         try:
+            comfy.model_management.in_training = True
             _run_training_loop(
                 guider,
                 train_sampler,
@@ -1123,6 +1229,7 @@
                 multi_res,
             )
         finally:
+            comfy.model_management.in_training = False
             # Eject bypass hooks if they were injected
             if bypass_injections is not None:
                 for injection in bypass_injections:
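
The in_training flag is raised right before the training loop and cleared in the finally block, so it is reset even if the loop raises partway through. A compact sketch of that guard pattern with a hypothetical stand-in for the flag holder (the _State class and run_steps function below are made up for illustration):

class _State:
    # stand-in for comfy.model_management in this sketch
    in_training = False

def run_steps():
    # stand-in training loop that fails partway through
    raise RuntimeError("simulated failure mid-training")

try:
    try:
        _State.in_training = True
        run_steps()
    finally:
        # mirrors the commit: the flag is cleared on success and on error alike
        _State.in_training = False
except RuntimeError:
    pass

assert _State.in_training is False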