From 4004af32907ac8e47ad39b3a7eb1738c77d8daa3 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Fri, 5 Dec 2025 17:24:55 +0800
Subject: [PATCH] Custom guider for correct offloading behavior

---
 comfy/sampler_helpers.py    |   9 +--
 comfy_extras/nodes_train.py | 105 ++++++++++++++++++++++++++++++-----
 2 files changed, 96 insertions(+), 18 deletions(-)

diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index e46971afb..e158e8a84 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required
 
-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)
 
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
     memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
+    models_list = [model] if not skip_load_model else []
+    comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model
 
     return real_model, conds, models
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index a24d3b199..71b307389 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
 from typing_extensions import override
 
 import comfy.samplers
+import comfy.sampler_helpers
 import comfy.sd
 import comfy.utils
 import comfy.model_management
@@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
 from comfy.utils import ProgressBar
 
 
+class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
+    """
+    CFGGuider with modifications for training specific logic
+    """
+    def outer_sample(
+        self,
+        noise,
+        latent_image,
+        sampler,
+        sigmas,
+        denoise_mask=None,
+        callback=None,
+        disable_pbar=False,
+        seed=None,
+        latent_shapes=None,
+    ):
+        self.inner_model, self.conds, self.loaded_models = (
+            comfy.sampler_helpers.prepare_sampling(
+                self.model_patcher,
+                noise.shape,
+                self.conds,
+                self.model_options,
+                skip_load_model=True,  # skip load model as we manage it in TrainLoraNode.execute()
+            )
+        )
+        device = self.model_patcher.load_device
+
+        if denoise_mask is not None:
+            denoise_mask = comfy.sampler_helpers.prepare_mask(
+                denoise_mask, noise.shape, device
+            )
+
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+        sigmas = sigmas.to(device)
+        comfy.samplers.cast_to_load_options(
+            self.model_options, device=device, dtype=self.model_patcher.model_dtype()
+        )
+
+        try:
+            self.model_patcher.pre_run()
+            output = self.inner_sample(
+                noise,
+                latent_image,
+                device,
+                sampler,
+                sigmas,
+                denoise_mask,
+                callback,
+                disable_pbar,
+                seed,
+                latent_shapes=latent_shapes,
+            )
+        finally:
+            self.model_patcher.cleanup()
+
+        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
+        del self.inner_model
+        del self.loaded_models
+        return output
+
+
 def make_batch_extra_option_dict(d, indicies, full_size=None):
     new_dict = {}
     for k, v in d.items():
@@ -77,7 +140,9 @@ class TrainSampler(comfy.samplers.Sampler):
         self.training_dtype = training_dtype
         self.real_dataset: list[torch.Tensor] | None = real_dataset
         # Bucket mode data
-        self.bucket_latents: list[torch.Tensor] | None = bucket_latents  # list of (Bi, C, Hi, Wi)
+        self.bucket_latents: list[torch.Tensor] | None = (
+            bucket_latents  # list of (Bi, C, Hi, Wi)
+        )
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
             self._init_bucket_data(bucket_latents)
@@ -511,7 +576,9 @@ def _load_existing_lora(existing_lora):
     return existing_weights, existing_steps
 
 
-def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank):
+def _create_weight_adapter(
+    module, module_name, existing_weights, algorithm, lora_dtype, rank
+):
     """Create a weight adapter for a module with weight.
 
     Args:
@@ -663,7 +730,9 @@ def _create_loss_function(loss_function_name):
     return torch.nn.SmoothL1Loss()
 
 
-def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res):
+def _run_training_loop(
+    guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
+):
     """Execute the training loop.
 
     Args:
@@ -815,11 +884,6 @@ class TrainLoraNode(io.ComfyNode):
                     default=False,
                     tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
                 ),
-                io.Boolean.Input(
-                    "offloading",
-                    default=False,
-                    tooltip="",
-                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -855,7 +919,6 @@ class TrainLoraNode(io.ComfyNode):
         gradient_checkpointing,
         existing_lora,
         bucket_mode,
-        offloading,
     ):
         # Extract scalars from lists (due to is_input_list=True)
         model = model[0]
@@ -890,7 +953,9 @@ class TrainLoraNode(io.ComfyNode):
         mp.set_model_compute_dtype(dtype)
 
         # Prepare latents and compute counts
-        latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode)
+        latents, num_images, multi_res = _prepare_latents_and_count(
+            latents, dtype, bucket_mode
+        )
 
         # Validate and expand conditioning
         positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
@@ -905,7 +970,9 @@ class TrainLoraNode(io.ComfyNode):
         )
 
         # Create optimizer and loss function
-        optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate)
+        optimizer = _create_optimizer(
+            optimizer_name, lora_sd.values(), learning_rate
+        )
         criterion = _create_loss_function(loss_function_name)
 
         # Setup gradient checkpointing
@@ -918,8 +985,10 @@ class TrainLoraNode(io.ComfyNode):
         # Setup models for training
         mp.model.requires_grad_(False)
         torch.cuda.empty_cache()
+        # With force_full_load=False, offloading would be possible here,
+        # but offloading during training requires custom autograd hooks for the fwd/bwd passes
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=not offloading
+            [mp], memory_required=1e20, force_full_load=True
         )
         torch.cuda.empty_cache()
 
@@ -956,12 +1025,20 @@ class TrainLoraNode(io.ComfyNode):
         )
 
         # Setup guider
-        guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
+        guider = TrainGuider(mp)
         guider.set_conds(positive)
 
         # Run training loop
         try:
-            _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res)
+            _run_training_loop(
+                guider,
+                train_sampler,
+                latents,
+                num_images,
+                seed,
+                bucket_mode,
+                multi_res,
+            )
         finally:
             for m in mp.model.modules():
                 unpatch(m)
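
A minimal usage sketch of the behavior this patch enables, for illustration only and not part of the diff: TrainLoraNode.execute() loads the patched model fully onto the GPU once, and TrainGuider passes skip_load_model=True to prepare_sampling(), so the per-step sampling call no longer re-issues load_models_gpu for the main model and cannot disturb that placement. The driver function below and its argument names (mp, positive, train_sampler, sigmas, latents, seed) are hypothetical stand-ins for the values the node already prepares.

# Illustrative sketch only (not part of the patch): how the pieces are meant
# to compose after this change. run_lora_training_pass and its argument list
# are hypothetical placeholders, not new API.
import torch

import comfy.model_management
from comfy_extras.nodes_train import TrainGuider  # class added by this patch


def run_lora_training_pass(mp, positive, train_sampler, sigmas, latents, seed):
    # The node owns model placement: load the patched model fully onto the GPU once.
    comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

    # TrainGuider.outer_sample() calls prepare_sampling(..., skip_load_model=True),
    # so each sampling call only prepares the conds/additional models and does not
    # trigger another load_models_gpu for the main model.
    guider = TrainGuider(mp)
    guider.set_conds(positive)

    noise = torch.zeros_like(latents)
    return guider.sample(noise, latents, train_sampler, sigmas, seed=seed)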