diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index f9597de5b..5e764bb46 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: return out.to(dtype=torch.float32, device=pos.device) +def _apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + + return x_out.reshape(*x.shape).type_as(x) + + +def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + + try: import comfy.quant_ops - apply_rope = comfy.quant_ops.ck.apply_rope - apply_rope1 = comfy.quant_ops.ck.apply_rope1 + q_apply_rope = comfy.quant_ops.ck.apply_rope + q_apply_rope1 = comfy.quant_ops.ck.apply_rope1 + def apply_rope(xq, xk, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope(xq, xk, freqs_cis) + else: + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + def apply_rope1(x, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope1(x, freqs_cis) + else: + return q_apply_rope1(x, freqs_cis) except: logging.warning("No comfy kitchen, using old apply_rope functions.") - def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) - - return x_out.reshape(*x.shape).type_as(x) - - def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + apply_rope = _apply_rope + apply_rope1 = _apply_rope1 diff --git a/comfy/model_management.py b/comfy/model_management.py index 6b1166b94..af266348a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -54,6 +54,11 @@ cpu_state = CPUState.GPU total_vram = 0 + +# Training Related State +in_training = False + + def get_supported_float8_types(): float8_types = [] try: diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 9134e6d71..1f75f2ba7 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) + return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? - memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) + if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load + memory_required = 1e20 + minimum_memory_required = None + else: + memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) + memory_required += inference_memory + minimum_memory_required += inference_memory + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) real_model = model.model return real_model, conds, models diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py index d4aaf98ca..b9d5ec7d9 100644 --- a/comfy/weight_adapter/bypass.py +++ b/comfy/weight_adapter/bypass.py @@ -21,6 +21,7 @@ from typing import Optional, Union import torch import torch.nn as nn +import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase from comfy.patcher_extension import PatcherInjection @@ -181,18 +182,21 @@ class BypassForwardHook: ) return # Already injected - # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward - device = None + # Move adapter weights to compute device (GPU) + # Use get_torch_device() instead of module.weight.device because + # with offloading, module weights may be on CPU while compute happens on GPU + device = comfy.model_management.get_torch_device() + + # Get dtype from module weight if available dtype = None if hasattr(self.module, "weight") and self.module.weight is not None: - device = self.module.weight.device dtype = self.module.weight.dtype - elif hasattr(self.module, "W_q"): # Quantized layers might use different attr - device = self.module.W_q.device - dtype = self.module.W_q.dtype - if device is not None: - self._move_adapter_weights_to_device(device, dtype) + # Only use dtype if it's a standard float type, not quantized + if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = None + + self._move_adapter_weights_to_device(device, dtype) self.original_forward = self.module.forward self.module.forward = self._bypass_forward diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 024a89391..18455c882 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -4,6 +4,7 @@ import os import numpy as np import safetensors import torch +import torch.nn as nn import torch.utils.checkpoint from tqdm.auto import trange from PIL import Image, ImageDraw, ImageFont @@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): """ CFGGuider with modifications for training specific logic """ + + def __init__(self, *args, offloading=False, **kwargs): + super().__init__(*args, **kwargs) + self.offloading = offloading + def outer_sample( self, noise, @@ -45,7 +51,8 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic): noise.shape, self.conds, self.model_options, - force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded + force_full_load=False, + force_offload=self.offloading, ) ) device = self.model_patcher.load_device @@ -404,16 +411,97 @@ def find_all_highest_child_module_with_forward( return result -def patch(m): +def find_modules_at_depth( + model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None +) -> list[nn.Module]: + """ + Find modules at a specific depth level for gradient checkpointing. + + Args: + model: The model to search + depth: Target depth level (1 = top-level blocks, 2 = their children, etc.) + result: Accumulator for results + current_depth: Current recursion depth + name: Current module name for logging + + Returns: + List of modules at the target depth + """ + if result is None: + result = [] + name = name or "root" + + # Skip container modules (they don't have meaningful forward) + is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict)) + has_forward = hasattr(model, "forward") and not is_container + + if has_forward: + current_depth += 1 + if current_depth == depth: + result.append(model) + logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})") + return result + + # Recurse into children + for next_name, child in model.named_children(): + find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}") + + return result + + +class OffloadCheckpointFunction(torch.autograd.Function): + """ + Gradient checkpointing that works with weight offloading. + + Forward: no_grad -> compute -> weights can be freed + Backward: enable_grad -> recompute -> backward -> weights can be freed + + For single input, single output modules (Linear, Conv*). + """ + + @staticmethod + def forward(ctx, x: torch.Tensor, forward_fn): + ctx.save_for_backward(x) + ctx.forward_fn = forward_fn + with torch.no_grad(): + return forward_fn(x) + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + x, = ctx.saved_tensors + forward_fn = ctx.forward_fn + + # Clear context early + ctx.forward_fn = None + + with torch.enable_grad(): + x_detached = x.detach().requires_grad_(True) + y = forward_fn(x_detached) + y.backward(grad_out) + grad_x = x_detached.grad + + # Explicit cleanup + del y, x_detached, forward_fn + + return grad_x, None + + +def patch(m, offloading=False): if not hasattr(m, "forward"): return org_forward = m.forward - def fwd(args, kwargs): - return org_forward(*args, **kwargs) + # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output) + if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): + def checkpointing_fwd(x): + return OffloadCheckpointFunction.apply(x, org_forward) + # Branch 2: Others -> standard checkpoint + else: + def fwd(args, kwargs): + return org_forward(*args, **kwargs) - def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) m.org_forward = org_forward m.forward = checkpointing_fwd @@ -936,6 +1024,18 @@ class TrainLoraNode(io.ComfyNode): default=True, tooltip="Use gradient checkpointing for training.", ), + io.Int.Input( + "checkpoint_depth", + default=1, + min=1, + max=5, + tooltip="Depth level for gradient checkpointing.", + ), + io.Int.Input( + "offloading", + default=False, + tooltip="Depth level for gradient checkpointing.", + ), io.Combo.Input( "existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], @@ -982,6 +1082,8 @@ class TrainLoraNode(io.ComfyNode): lora_dtype, algorithm, gradient_checkpointing, + checkpoint_depth, + offloading, existing_lora, bucket_mode, bypass_mode, @@ -1000,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode): lora_dtype = lora_dtype[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] + checkpoint_depth = checkpoint_depth[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] bypass_mode = bypass_mode[0] @@ -1054,16 +1157,18 @@ class TrainLoraNode(io.ComfyNode): # Setup gradient checkpointing if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward( - mp.model.diffusion_model - ): - patch(m) + modules_to_patch = find_modules_at_depth( + mp.model.diffusion_model, depth=checkpoint_depth + ) + logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}") + for m in modules_to_patch: + patch(m, offloading=offloading) torch.cuda.empty_cache() # With force_full_load=False we should be able to have offloading # But for offloading in training we need custom AutoGrad hooks for fwd/bwd comfy.model_management.load_models_gpu( - [mp], memory_required=1e20, force_full_load=True + [mp], memory_required=1e20, force_full_load=False ) torch.cuda.empty_cache() @@ -1100,7 +1205,7 @@ class TrainLoraNode(io.ComfyNode): ) # Setup guider - guider = TrainGuider(mp) + guider = TrainGuider(mp, offloading) guider.set_conds(positive) # Inject bypass hooks if bypass mode is enabled @@ -1113,6 +1218,7 @@ class TrainLoraNode(io.ComfyNode): # Run training loop try: + comfy.model_management.in_training = True _run_training_loop( guider, train_sampler, @@ -1123,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + comfy.model_management.in_training = False # Eject bypass hooks if they were injected if bypass_injections is not None: for injection in bypass_injections: