Custom guider for correct offloading behavior

Kohaku-Blueleaf 2025-12-05 17:24:55 +08:00
parent bf573e94a2
commit 4004af3290
2 changed files with 97 additions and 19 deletions

comfy/sampler_helpers.py

@@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required

-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)

-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
     memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
+    models_list = [model] if not skip_load_model else []
+    comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model
     return real_model, conds, models
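
The new skip_load_model flag lets a caller keep ownership of main-model placement: when True, _prepare_sampling loads only the auxiliary models gathered from the conds and model options, and assumes the patched model is already where it needs to be. A minimal sketch of the intended call pattern (assuming a ModelPatcher mp, a noise tensor, and prepared conds, mirroring how TrainLoraNode.execute() uses it below):

import comfy.model_management
import comfy.sampler_helpers

# The caller loads the main model itself, fully on-device.
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
# prepare_sampling then only loads the additional models (controlnets etc.).
real_model, conds, extra_models = comfy.sampler_helpers.prepare_sampling(
    mp, noise.shape, conds, skip_load_model=True
)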

comfy_extras/nodes_train.py

@@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
 from typing_extensions import override

 import comfy.samplers
+import comfy.sampler_helpers
 import comfy.sd
 import comfy.utils
 import comfy.model_management
@@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
 from comfy.utils import ProgressBar

+
+class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
+    """
+    CFGGuider with modifications for training specific logic
+    """
+
+    def outer_sample(
+        self,
+        noise,
+        latent_image,
+        sampler,
+        sigmas,
+        denoise_mask=None,
+        callback=None,
+        disable_pbar=False,
+        seed=None,
+        latent_shapes=None,
+    ):
+        self.inner_model, self.conds, self.loaded_models = (
+            comfy.sampler_helpers.prepare_sampling(
+                self.model_patcher,
+                noise.shape,
+                self.conds,
+                self.model_options,
+                skip_load_model=True,  # skip load model as we manage it in TrainLoraNode.execute()
+            )
+        )
+        device = self.model_patcher.load_device
+
+        if denoise_mask is not None:
+            denoise_mask = comfy.sampler_helpers.prepare_mask(
+                denoise_mask, noise.shape, device
+            )
+
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+        sigmas = sigmas.to(device)
+        comfy.samplers.cast_to_load_options(
+            self.model_options, device=device, dtype=self.model_patcher.model_dtype()
+        )
+
+        try:
+            self.model_patcher.pre_run()
+            output = self.inner_sample(
+                noise,
+                latent_image,
+                device,
+                sampler,
+                sigmas,
+                denoise_mask,
+                callback,
+                disable_pbar,
+                seed,
+                latent_shapes=latent_shapes,
+            )
+        finally:
+            self.model_patcher.cleanup()
+            comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
+            del self.inner_model
+            del self.loaded_models
+        return output
+
+
 def make_batch_extra_option_dict(d, indicies, full_size=None):
     new_dict = {}
     for k, v in d.items():
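
TrainGuider reuses the overall shape of CFGGuider.outer_sample but passes skip_load_model=True, because TrainLoraNode.execute() has already loaded the model with force_full_load=True before sampling starts. A rough usage sketch (assuming mp, positive, train_sampler, sigmas, noise, latent_image, and seed are prepared the way the node below prepares them):

guider = TrainGuider(mp)
guider.set_conds(positive)
# sample() reaches outer_sample() internally without triggering a second model load
output = guider.sample(noise, latent_image, train_sampler, sigmas, seed=seed)
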
@@ -77,7 +140,9 @@ class TrainSampler(comfy.samplers.Sampler):
         self.training_dtype = training_dtype
         self.real_dataset: list[torch.Tensor] | None = real_dataset
         # Bucket mode data
-        self.bucket_latents: list[torch.Tensor] | None = bucket_latents # list of (Bi, C, Hi, Wi)
+        self.bucket_latents: list[torch.Tensor] | None = (
+            bucket_latents  # list of (Bi, C, Hi, Wi)
+        )
         # Precompute bucket offsets and weights for sampling
         if bucket_latents is not None:
             self._init_bucket_data(bucket_latents)
@ -511,7 +576,9 @@ def _load_existing_lora(existing_lora):
return existing_weights, existing_steps return existing_weights, existing_steps
def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank): def _create_weight_adapter(
module, module_name, existing_weights, algorithm, lora_dtype, rank
):
"""Create a weight adapter for a module with weight. """Create a weight adapter for a module with weight.
Args: Args:
@@ -663,7 +730,9 @@ def _create_loss_function(loss_function_name):
     return torch.nn.SmoothL1Loss()

-def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res):
+def _run_training_loop(
+    guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
+):
     """Execute the training loop.

     Args:
@@ -815,11 +884,6 @@ class TrainLoraNode(io.ComfyNode):
                     default=False,
                     tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
                 ),
-                io.Boolean.Input(
-                    "offloading",
-                    default=False,
-                    tooltip="",
-                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -855,7 +919,6 @@ class TrainLoraNode(io.ComfyNode):
         gradient_checkpointing,
         existing_lora,
         bucket_mode,
-        offloading,
     ):
         # Extract scalars from lists (due to is_input_list=True)
         model = model[0]
@@ -890,7 +953,9 @@ class TrainLoraNode(io.ComfyNode):
         mp.set_model_compute_dtype(dtype)

         # Prepare latents and compute counts
-        latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode)
+        latents, num_images, multi_res = _prepare_latents_and_count(
+            latents, dtype, bucket_mode
+        )

         # Validate and expand conditioning
         positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
@ -905,7 +970,9 @@ class TrainLoraNode(io.ComfyNode):
) )
# Create optimizer and loss function # Create optimizer and loss function
optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate) optimizer = _create_optimizer(
optimizer_name, lora_sd.values(), learning_rate
)
criterion = _create_loss_function(loss_function_name) criterion = _create_loss_function(loss_function_name)
# Setup gradient checkpointing # Setup gradient checkpointing
@@ -918,8 +985,10 @@ class TrainLoraNode(io.ComfyNode):
         # Setup models for training
         mp.model.requires_grad_(False)
         torch.cuda.empty_cache()
+        # With force_full_load=False we should be able to have offloading
+        # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=not offloading
+            [mp], memory_required=1e20, force_full_load=True
         )
         torch.cuda.empty_cache()
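
The comment added above explains why the offloading toggle was dropped: with force_full_load=False, parts of the model would live off-GPU, but the tensors autograd saves during the forward pass still have to be resident during backward, so training-time offloading needs hook machinery around the forward/backward passes. As a rough illustration of that kind of machinery (not part of this commit), PyTorch's saved-tensor hooks can stage backward-pass tensors on the CPU:

import torch

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

# save_on_cpu() moves every tensor saved for backward to host memory and
# copies it back to the GPU on demand as backward consumes it.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).sum()
loss.backward()
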
@ -956,12 +1025,20 @@ class TrainLoraNode(io.ComfyNode):
) )
# Setup guider # Setup guider
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider = TrainGuider(mp)
guider.set_conds(positive) guider.set_conds(positive)
# Run training loop # Run training loop
try: try:
_run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res) _run_training_loop(
guider,
train_sampler,
latents,
num_images,
seed,
bucket_mode,
multi_res,
)
finally: finally:
for m in mp.model.modules(): for m in mp.model.modules():
unpatch(m) unpatch(m)
@@ -977,7 +1054,7 @@ class TrainLoraNode(io.ComfyNode):
         return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)

-class LoraModelLoader(io.ComfyNode):
+class LoraModelLoader(io.ComfyNode):#
     @classmethod
     def define_schema(cls):
         return io.Schema(