commit 3153378616
Kohaku-Blueleaf, 2026-02-02 21:45:20 +08:00 (committed via GitHub)
5 changed files with 174 additions and 37 deletions


@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     return out.to(dtype=torch.float32, device=pos.device)
 
+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+    x_out = freqs_cis[..., 0] * x_[..., 0]
+    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+    return x_out.reshape(*x.shape).type_as(x)
+
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
 try:
     import comfy.quant_ops
-    apply_rope = comfy.quant_ops.ck.apply_rope
-    apply_rope1 = comfy.quant_ops.ck.apply_rope1
+    q_apply_rope = comfy.quant_ops.ck.apply_rope
+    q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+
+    def apply_rope(xq, xk, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope(xq, xk, freqs_cis)
+        else:
+            return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
+    def apply_rope1(x, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope1(x, freqs_cis)
+        else:
+            return q_apply_rope1(x, freqs_cis)
 except:
     logging.warning("No comfy kitchen, using old apply_rope functions.")
-    def apply_rope1(x: Tensor, freqs_cis: Tensor):
-        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-        x_out = freqs_cis[..., 0] * x_[..., 0]
-        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-        return x_out.reshape(*x.shape).type_as(x)
-
-    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+    apply_rope = _apply_rope
+    apply_rope1 = _apply_rope1

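The new `_apply_rope1` keeps the original pairwise-rotation arithmetic so the training path stays differentiable in plain PyTorch while inference can keep using the quantized kernel. As a quick sanity check (a standalone sketch, not part of the commit; `make_freqs_cis` and `apply_rope1_ref` are local names, and the (dim/2, 2, 2) rotation-matrix layout is assumed to match what `rope()` produces), the rotation should agree with complex multiplication of consecutive feature pairs:

import torch

def make_freqs_cis(angles: torch.Tensor) -> torch.Tensor:
    # angles: (seq, dim/2) -> (seq, dim/2, 2, 2) with rows [[cos, -sin], [sin, cos]]
    cos, sin = torch.cos(angles), torch.sin(angles)
    return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2)

def apply_rope1_ref(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as _apply_rope1 in the hunk above (dtype handling dropped).
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return out.reshape(*x.shape)

seq, dim = 4, 8
x = torch.randn(seq, dim)
angles = torch.randn(seq, dim // 2)
freqs_cis = make_freqs_cis(angles)

# Reference: treat each (even, odd) feature pair as a complex number and rotate it.
xc = torch.view_as_complex(x.reshape(seq, dim // 2, 2).contiguous())
rot = torch.polar(torch.ones_like(angles), angles)
expected = torch.view_as_real(xc * rot).reshape(seq, dim)

assert torch.allclose(apply_rope1_ref(x, freqs_cis), expected, atol=1e-6)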

@@ -54,6 +54,11 @@ cpu_state = CPUState.GPU
 
 total_vram = 0
 
+# Training Related State
+in_training = False
+
 def get_supported_float8_types():
     float8_types = []
     try:

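The flag is module-level state that other code paths (such as the `apply_rope` dispatch above) read at call time. A minimal usage sketch, assuming a ComfyUI runtime; `run_training_steps` is a placeholder for the caller's loop (the commit itself sets the flag around `_run_training_loop`, as shown in the last hunks below):

import comfy.model_management

comfy.model_management.in_training = True
try:
    run_training_steps()  # placeholder for the training loop
finally:
    comfy.model_management.in_training = False  # always reset, even on error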

@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required
 
-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
 
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
-    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+    if force_offload:  # In training + offload enabled, we want to force prepare sampling to trigger partial load
+        memory_required = 1e20
+        minimum_memory_required = None
+    else:
+        memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+        memory_required += inference_memory
+        minimum_memory_required += inference_memory
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
     real_model = model.model
 
     return real_model, conds, models

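With `force_offload=True` the memory estimate is skipped entirely: `memory_required=1e20` guarantees the request cannot fit, so `load_models_gpu` falls back to a partial (lowvram-style) load instead of keeping the whole model resident. A hedged usage sketch, assuming this hunk lives in `comfy.sampler_helpers` and that `model_patcher`, `noise`, and `conds` are already prepared by the caller:

import comfy.sampler_helpers  # assumed module path for the function above

real_model, conds, loaded_models = comfy.sampler_helpers.prepare_sampling(
    model_patcher,       # hypothetical ModelPatcher
    noise.shape,         # hypothetical latent noise tensor
    conds,
    model_options={},
    force_full_load=False,
    force_offload=True,  # skip estimate_memory(); force a partial load
)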

@@ -21,6 +21,7 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 
+import comfy.model_management
 from .base import WeightAdapterBase, WeightAdapterTrainBase
 from comfy.patcher_extension import PatcherInjection
@@ -181,18 +182,21 @@ class BypassForwardHook:
             )
             return  # Already injected
 
-        # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
-        device = None
+        # Move adapter weights to compute device (GPU)
+        # Use get_torch_device() instead of module.weight.device because
+        # with offloading, module weights may be on CPU while compute happens on GPU
+        device = comfy.model_management.get_torch_device()
+
+        # Get dtype from module weight if available
         dtype = None
         if hasattr(self.module, "weight") and self.module.weight is not None:
-            device = self.module.weight.device
             dtype = self.module.weight.dtype
         elif hasattr(self.module, "W_q"):  # Quantized layers might use different attr
-            device = self.module.W_q.device
             dtype = self.module.W_q.dtype
-        if device is not None:
-            self._move_adapter_weights_to_device(device, dtype)
+        # Only use dtype if it's a standard float type, not quantized
+        if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            dtype = None
+        self._move_adapter_weights_to_device(device, dtype)
 
         self.original_forward = self.module.forward
         self.module.forward = self._bypass_forward

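The device change matters because, under offloading, the wrapped module's stored weight may sit on the CPU while the forward pass itself runs on the GPU; pinning the adapter to `module.weight.device` would then drag the LoRA computation onto the CPU. A standalone illustration of the intended split (not ComfyUI code; names are local to this sketch):

import torch
import torch.nn as nn
import torch.nn.functional as F

compute_device = "cuda" if torch.cuda.is_available() else "cpu"

layer = nn.Linear(16, 16)  # base weight stays on CPU, as with an offloaded module
lora_down = torch.zeros(4, 16, device=compute_device, requires_grad=True)
lora_up = torch.zeros(16, 4, device=compute_device, requires_grad=True)

x = torch.randn(2, 16, device=compute_device)
# Base path: the offloaded weight is brought to the compute device for this call only.
y = F.linear(x, layer.weight.to(compute_device), layer.bias.to(compute_device))
# Adapter path: already resident on the compute device, no per-forward transfer.
y = y + F.linear(F.linear(x, lora_down), lora_up)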

@@ -4,6 +4,7 @@ import os
 import numpy as np
 import safetensors
 import torch
+import torch.nn as nn
 import torch.utils.checkpoint
 from tqdm.auto import trange
 from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
     """
     CFGGuider with modifications for training specific logic
     """
+
+    def __init__(self, *args, offloading=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.offloading = offloading
+
     def outer_sample(
         self,
         noise,
@@ -45,7 +51,8 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                 noise.shape,
                 self.conds,
                 self.model_options,
-                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
+                force_full_load=False,
+                force_offload=self.offloading,
             )
         )
         device = self.model_patcher.load_device
@@ -404,16 +411,97 @@ def find_all_highest_child_module_with_forward(
     return result
 
 
-def patch(m):
+def find_modules_at_depth(
+    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+    """
+    Find modules at a specific depth level for gradient checkpointing.
+
+    Args:
+        model: The model to search
+        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+        result: Accumulator for results
+        current_depth: Current recursion depth
+        name: Current module name for logging
+
+    Returns:
+        List of modules at the target depth
+    """
+    if result is None:
+        result = []
+        name = name or "root"
+
+    # Skip container modules (they don't have meaningful forward)
+    is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+    has_forward = hasattr(model, "forward") and not is_container
+
+    if has_forward:
+        current_depth += 1
+        if current_depth == depth:
+            result.append(model)
+            logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+            return result
+
+    # Recurse into children
+    for next_name, child in model.named_children():
+        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+
+    return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+    """
+    Gradient checkpointing that works with weight offloading.
+
+    Forward: no_grad -> compute -> weights can be freed
+    Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+    For single input, single output modules (Linear, Conv*).
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, forward_fn):
+        ctx.save_for_backward(x)
+        ctx.forward_fn = forward_fn
+        with torch.no_grad():
+            return forward_fn(x)
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        x, = ctx.saved_tensors
+        forward_fn = ctx.forward_fn
+        # Clear context early
+        ctx.forward_fn = None
+
+        with torch.enable_grad():
+            x_detached = x.detach().requires_grad_(True)
+            y = forward_fn(x_detached)
+            y.backward(grad_out)
+            grad_x = x_detached.grad
+
+        # Explicit cleanup
+        del y, x_detached, forward_fn
+
+        return grad_x, None
+
+
+def patch(m, offloading=False):
     if not hasattr(m, "forward"):
         return
     org_forward = m.forward
 
-    def fwd(args, kwargs):
-        return org_forward(*args, **kwargs)
+    # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
+    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        def checkpointing_fwd(x):
+            return OffloadCheckpointFunction.apply(x, org_forward)
+    # Branch 2: Others -> standard checkpoint
+    else:
+        def fwd(args, kwargs):
+            return org_forward(*args, **kwargs)
 
-    def checkpointing_fwd(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+        def checkpointing_fwd(*args, **kwargs):
+            return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
 
     m.org_forward = org_forward
     m.forward = checkpointing_fwd
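`OffloadCheckpointFunction` stores only the module input; the module's output is recomputed (and its weights re-read) during backward, which is what allows the weights to be offloaded in between. A quick numerical check (standalone sketch, assuming the class above is in scope) that the recompute path gives the same input gradients as a plain call on an `nn.Linear`:

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)

x_plain = torch.randn(4, 8, requires_grad=True)
x_ckpt = x_plain.detach().clone().requires_grad_(True)

layer(x_plain).sum().backward()                                           # ordinary autograd
OffloadCheckpointFunction.apply(x_ckpt, layer.forward).sum().backward()   # recompute in backward

assert torch.allclose(x_plain.grad, x_ckpt.grad, atol=1e-6)

Note that `patch()` only routes `nn.Linear`/`nn.Conv*` through this path; modules with multiple inputs or keyword arguments keep the standard `torch.utils.checkpoint` wrapper.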
@@ -936,6 +1024,18 @@ class TrainLoraNode(io.ComfyNode):
                     default=True,
                     tooltip="Use gradient checkpointing for training.",
                 ),
+                io.Int.Input(
+                    "checkpoint_depth",
+                    default=1,
+                    min=1,
+                    max=5,
+                    tooltip="Depth level for gradient checkpointing.",
+                ),
+                io.Boolean.Input(
+                    "offloading",
+                    default=False,
+                    tooltip="Offload model weights to save VRAM during training.",
+                ),
                 io.Combo.Input(
                     "existing_lora",
                     options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1082,8 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype,
         algorithm,
         gradient_checkpointing,
+        checkpoint_depth,
+        offloading,
         existing_lora,
         bucket_mode,
         bypass_mode,
@@ -1000,6 +1102,8 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = lora_dtype[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
+        checkpoint_depth = checkpoint_depth[0]
+        offloading = offloading[0]
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
         bypass_mode = bypass_mode[0]
@@ -1054,16 +1157,18 @@ class TrainLoraNode(io.ComfyNode):
         # Setup gradient checkpointing
         if gradient_checkpointing:
-            for m in find_all_highest_child_module_with_forward(
-                mp.model.diffusion_model
-            ):
-                patch(m)
+            modules_to_patch = find_modules_at_depth(
+                mp.model.diffusion_model, depth=checkpoint_depth
+            )
+            logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+            for m in modules_to_patch:
+                patch(m, offloading=offloading)
         torch.cuda.empty_cache()
 
         # With force_full_load=False we should be able to have offloading
         # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=True
+            [mp], memory_required=1e20, force_full_load=False
         )
         torch.cuda.empty_cache()
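What `checkpoint_depth` and `offloading` buy is lower activation and weight residency on the GPU at the cost of recomputation and extra transfers. A rough way to compare settings (hedged sketch; requires CUDA, and `model`/`batch` stand in for whatever the training step actually runs):

import torch

torch.cuda.reset_peak_memory_stats()
loss = model(batch).float().pow(2).mean()  # hypothetical forward pass + toy loss
loss.backward()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")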
@@ -1100,7 +1205,7 @@ class TrainLoraNode(io.ComfyNode):
         )
 
         # Setup guider
-        guider = TrainGuider(mp)
+        guider = TrainGuider(mp, offloading=offloading)
         guider.set_conds(positive)
 
         # Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1218,7 @@ class TrainLoraNode(io.ComfyNode):
         # Run training loop
         try:
+            comfy.model_management.in_training = True
             _run_training_loop(
                 guider,
                 train_sampler,
@@ -1123,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
                 multi_res,
             )
         finally:
+            comfy.model_management.in_training = False
             # Eject bypass hooks if they were injected
             if bypass_injections is not None:
                 for injection in bypass_injections: