mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
Custom guider for correct offloading behavior
This commit is contained in:
parent
bf573e94a2
commit
4004af3290
@ -122,20 +122,21 @@ def estimate_memory(model, noise_shape, conds):
|
|||||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||||
return memory_required, minimum_memory_required
|
return memory_required, minimum_memory_required
|
||||||
|
|
||||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||||
_prepare_sampling,
|
_prepare_sampling,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||||
)
|
)
|
||||||
return executor.execute(model, noise_shape, conds, model_options=model_options)
|
return executor.execute(model, noise_shape, conds, model_options=model_options, skip_load_model=skip_load_model)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, skip_load_model=False):
|
||||||
real_model: BaseModel = None
|
real_model: BaseModel = None
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
|
models_list = [model] if not skip_load_model else []
|
||||||
|
comfy.model_management.load_models_gpu(models_list + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
|
||||||
return real_model, conds, models
|
return real_model, conds, models
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
|
import comfy.sampler_helpers
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -21,6 +22,68 @@ from comfy_api.latest import ComfyExtension, io, ui
|
|||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
|
|
||||||
|
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||||
|
"""
|
||||||
|
CFGGuider with modifications for training specific logic
|
||||||
|
"""
|
||||||
|
def outer_sample(
|
||||||
|
self,
|
||||||
|
noise,
|
||||||
|
latent_image,
|
||||||
|
sampler,
|
||||||
|
sigmas,
|
||||||
|
denoise_mask=None,
|
||||||
|
callback=None,
|
||||||
|
disable_pbar=False,
|
||||||
|
seed=None,
|
||||||
|
latent_shapes=None,
|
||||||
|
):
|
||||||
|
self.inner_model, self.conds, self.loaded_models = (
|
||||||
|
comfy.sampler_helpers.prepare_sampling(
|
||||||
|
self.model_patcher,
|
||||||
|
noise.shape,
|
||||||
|
self.conds,
|
||||||
|
self.model_options,
|
||||||
|
skip_load_model=True, # skip load model as we manage it in TrainLoraNode.execute()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
|
if denoise_mask is not None:
|
||||||
|
denoise_mask = comfy.sampler_helpers.prepare_mask(
|
||||||
|
denoise_mask, noise.shape, device
|
||||||
|
)
|
||||||
|
|
||||||
|
noise = noise.to(device)
|
||||||
|
latent_image = latent_image.to(device)
|
||||||
|
sigmas = sigmas.to(device)
|
||||||
|
comfy.samplers.cast_to_load_options(
|
||||||
|
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.model_patcher.pre_run()
|
||||||
|
output = self.inner_sample(
|
||||||
|
noise,
|
||||||
|
latent_image,
|
||||||
|
device,
|
||||||
|
sampler,
|
||||||
|
sigmas,
|
||||||
|
denoise_mask,
|
||||||
|
callback,
|
||||||
|
disable_pbar,
|
||||||
|
seed,
|
||||||
|
latent_shapes=latent_shapes,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self.model_patcher.cleanup()
|
||||||
|
|
||||||
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
|
del self.inner_model
|
||||||
|
del self.loaded_models
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||||
new_dict = {}
|
new_dict = {}
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
@ -77,7 +140,9 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self.training_dtype = training_dtype
|
self.training_dtype = training_dtype
|
||||||
self.real_dataset: list[torch.Tensor] | None = real_dataset
|
self.real_dataset: list[torch.Tensor] | None = real_dataset
|
||||||
# Bucket mode data
|
# Bucket mode data
|
||||||
self.bucket_latents: list[torch.Tensor] | None = bucket_latents # list of (Bi, C, Hi, Wi)
|
self.bucket_latents: list[torch.Tensor] | None = (
|
||||||
|
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||||
|
)
|
||||||
# Precompute bucket offsets and weights for sampling
|
# Precompute bucket offsets and weights for sampling
|
||||||
if bucket_latents is not None:
|
if bucket_latents is not None:
|
||||||
self._init_bucket_data(bucket_latents)
|
self._init_bucket_data(bucket_latents)
|
||||||
@ -511,7 +576,9 @@ def _load_existing_lora(existing_lora):
|
|||||||
return existing_weights, existing_steps
|
return existing_weights, existing_steps
|
||||||
|
|
||||||
|
|
||||||
def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank):
|
def _create_weight_adapter(
|
||||||
|
module, module_name, existing_weights, algorithm, lora_dtype, rank
|
||||||
|
):
|
||||||
"""Create a weight adapter for a module with weight.
|
"""Create a weight adapter for a module with weight.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -663,7 +730,9 @@ def _create_loss_function(loss_function_name):
|
|||||||
return torch.nn.SmoothL1Loss()
|
return torch.nn.SmoothL1Loss()
|
||||||
|
|
||||||
|
|
||||||
def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res):
|
def _run_training_loop(
|
||||||
|
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
|
||||||
|
):
|
||||||
"""Execute the training loop.
|
"""Execute the training loop.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -815,11 +884,6 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default=False,
|
default=False,
|
||||||
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
||||||
),
|
),
|
||||||
io.Boolean.Input(
|
|
||||||
"offloading",
|
|
||||||
default=False,
|
|
||||||
tooltip="",
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(
|
io.Model.Output(
|
||||||
@ -855,7 +919,6 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
existing_lora,
|
existing_lora,
|
||||||
bucket_mode,
|
bucket_mode,
|
||||||
offloading,
|
|
||||||
):
|
):
|
||||||
# Extract scalars from lists (due to is_input_list=True)
|
# Extract scalars from lists (due to is_input_list=True)
|
||||||
model = model[0]
|
model = model[0]
|
||||||
@ -890,7 +953,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
mp.set_model_compute_dtype(dtype)
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode)
|
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||||
|
latents, dtype, bucket_mode
|
||||||
|
)
|
||||||
|
|
||||||
# Validate and expand conditioning
|
# Validate and expand conditioning
|
||||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||||
@ -905,7 +970,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create optimizer and loss function
|
# Create optimizer and loss function
|
||||||
optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate)
|
optimizer = _create_optimizer(
|
||||||
|
optimizer_name, lora_sd.values(), learning_rate
|
||||||
|
)
|
||||||
criterion = _create_loss_function(loss_function_name)
|
criterion = _create_loss_function(loss_function_name)
|
||||||
|
|
||||||
# Setup gradient checkpointing
|
# Setup gradient checkpointing
|
||||||
@ -918,8 +985,10 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Setup models for training
|
# Setup models for training
|
||||||
mp.model.requires_grad_(False)
|
mp.model.requires_grad_(False)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
# With force_full_load=False we should be able to have offloading
|
||||||
|
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[mp], memory_required=1e20, force_full_load=not offloading
|
[mp], memory_required=1e20, force_full_load=True
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -956,12 +1025,20 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
guider = TrainGuider(mp)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(positive)
|
||||||
|
|
||||||
# Run training loop
|
# Run training loop
|
||||||
try:
|
try:
|
||||||
_run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res)
|
_run_training_loop(
|
||||||
|
guider,
|
||||||
|
train_sampler,
|
||||||
|
latents,
|
||||||
|
num_images,
|
||||||
|
seed,
|
||||||
|
bucket_mode,
|
||||||
|
multi_res,
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
for m in mp.model.modules():
|
for m in mp.model.modules():
|
||||||
unpatch(m)
|
unpatch(m)
|
||||||
@ -977,7 +1054,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
||||||
|
|
||||||
|
|
||||||
class LoraModelLoader(io.ComfyNode):
|
class LoraModelLoader(io.ComfyNode):#
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user