mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 11:32:31 +08:00
offloading implementation in training node
This commit is contained in:
parent
ec61c02bf6
commit
20bd2c0236
@ -4,6 +4,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
"""
|
"""
|
||||||
CFGGuider with modifications for training specific logic
|
CFGGuider with modifications for training specific logic
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, offloading=False, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.offloading = offloading
|
||||||
|
|
||||||
def outer_sample(
|
def outer_sample(
|
||||||
self,
|
self,
|
||||||
noise,
|
noise,
|
||||||
@ -45,7 +51,8 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
noise.shape,
|
noise.shape,
|
||||||
self.conds,
|
self.conds,
|
||||||
self.model_options,
|
self.model_options,
|
||||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
force_full_load=False,
|
||||||
|
force_offload=self.offloading,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
@ -404,16 +411,97 @@ def find_all_highest_child_module_with_forward(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def patch(m):
|
def find_modules_at_depth(
|
||||||
|
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
|
||||||
|
) -> list[nn.Module]:
|
||||||
|
"""
|
||||||
|
Find modules at a specific depth level for gradient checkpointing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to search
|
||||||
|
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
|
||||||
|
result: Accumulator for results
|
||||||
|
current_depth: Current recursion depth
|
||||||
|
name: Current module name for logging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of modules at the target depth
|
||||||
|
"""
|
||||||
|
if result is None:
|
||||||
|
result = []
|
||||||
|
name = name or "root"
|
||||||
|
|
||||||
|
# Skip container modules (they don't have meaningful forward)
|
||||||
|
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
|
||||||
|
has_forward = hasattr(model, "forward") and not is_container
|
||||||
|
|
||||||
|
if has_forward:
|
||||||
|
current_depth += 1
|
||||||
|
if current_depth == depth:
|
||||||
|
result.append(model)
|
||||||
|
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Recurse into children
|
||||||
|
for next_name, child in model.named_children():
|
||||||
|
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadCheckpointFunction(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Gradient checkpointing that works with weight offloading.
|
||||||
|
|
||||||
|
Forward: no_grad -> compute -> weights can be freed
|
||||||
|
Backward: enable_grad -> recompute -> backward -> weights can be freed
|
||||||
|
|
||||||
|
For single input, single output modules (Linear, Conv*).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: torch.Tensor, forward_fn):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
ctx.forward_fn = forward_fn
|
||||||
|
with torch.no_grad():
|
||||||
|
return forward_fn(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out: torch.Tensor):
|
||||||
|
x, = ctx.saved_tensors
|
||||||
|
forward_fn = ctx.forward_fn
|
||||||
|
|
||||||
|
# Clear context early
|
||||||
|
ctx.forward_fn = None
|
||||||
|
|
||||||
|
with torch.enable_grad():
|
||||||
|
x_detached = x.detach().requires_grad_(True)
|
||||||
|
y = forward_fn(x_detached)
|
||||||
|
y.backward(grad_out)
|
||||||
|
grad_x = x_detached.grad
|
||||||
|
|
||||||
|
# Explicit cleanup
|
||||||
|
del y, x_detached, forward_fn
|
||||||
|
|
||||||
|
return grad_x, None
|
||||||
|
|
||||||
|
|
||||||
|
def patch(m, offloading=False):
|
||||||
if not hasattr(m, "forward"):
|
if not hasattr(m, "forward"):
|
||||||
return
|
return
|
||||||
org_forward = m.forward
|
org_forward = m.forward
|
||||||
|
|
||||||
def fwd(args, kwargs):
|
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
|
||||||
return org_forward(*args, **kwargs)
|
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||||
|
def checkpointing_fwd(x):
|
||||||
|
return OffloadCheckpointFunction.apply(x, org_forward)
|
||||||
|
# Branch 2: Others -> standard checkpoint
|
||||||
|
else:
|
||||||
|
def fwd(args, kwargs):
|
||||||
|
return org_forward(*args, **kwargs)
|
||||||
|
|
||||||
def checkpointing_fwd(*args, **kwargs):
|
def checkpointing_fwd(*args, **kwargs):
|
||||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||||
|
|
||||||
m.org_forward = org_forward
|
m.org_forward = org_forward
|
||||||
m.forward = checkpointing_fwd
|
m.forward = checkpointing_fwd
|
||||||
@ -936,6 +1024,18 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default=True,
|
default=True,
|
||||||
tooltip="Use gradient checkpointing for training.",
|
tooltip="Use gradient checkpointing for training.",
|
||||||
),
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"checkpoint_depth",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=5,
|
||||||
|
tooltip="Depth level for gradient checkpointing.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"offloading",
|
||||||
|
default=False,
|
||||||
|
tooltip="Depth level for gradient checkpointing.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"existing_lora",
|
"existing_lora",
|
||||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||||
@ -982,6 +1082,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype,
|
lora_dtype,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
|
checkpoint_depth,
|
||||||
|
offloading,
|
||||||
existing_lora,
|
existing_lora,
|
||||||
bucket_mode,
|
bucket_mode,
|
||||||
bypass_mode,
|
bypass_mode,
|
||||||
@ -1000,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
|
checkpoint_depth = checkpoint_depth[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
@ -1054,16 +1157,18 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Setup gradient checkpointing
|
# Setup gradient checkpointing
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
for m in find_all_highest_child_module_with_forward(
|
modules_to_patch = find_modules_at_depth(
|
||||||
mp.model.diffusion_model
|
mp.model.diffusion_model, depth=checkpoint_depth
|
||||||
):
|
)
|
||||||
patch(m)
|
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
|
||||||
|
for m in modules_to_patch:
|
||||||
|
patch(m, offloading=offloading)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# With force_full_load=False we should be able to have offloading
|
# With force_full_load=False we should be able to have offloading
|
||||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[mp], memory_required=1e20, force_full_load=True
|
[mp], memory_required=1e20, force_full_load=False
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -1100,7 +1205,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
guider = TrainGuider(mp)
|
guider = TrainGuider(mp, offloading)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(positive)
|
||||||
|
|
||||||
# Inject bypass hooks if bypass mode is enabled
|
# Inject bypass hooks if bypass mode is enabled
|
||||||
@ -1113,6 +1218,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Run training loop
|
# Run training loop
|
||||||
try:
|
try:
|
||||||
|
comfy.model_management.in_training = True
|
||||||
_run_training_loop(
|
_run_training_loop(
|
||||||
guider,
|
guider,
|
||||||
train_sampler,
|
train_sampler,
|
||||||
@ -1123,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
multi_res,
|
multi_res,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
comfy.model_management.in_training = False
|
||||||
# Eject bypass hooks if they were injected
|
# Eject bypass hooks if they were injected
|
||||||
if bypass_injections is not None:
|
if bypass_injections is not None:
|
||||||
for injection in bypass_injections:
|
for injection in bypass_injections:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user