ComfyUI/comfy_extras/nodes/nodes_train.py

1238 lines
41 KiB
Python

import logging
import os
import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override
import comfy.model_management
import comfy.samplers
import comfy.sampler_helpers
import comfy.sd
import comfy.utils
from comfy import node_helpers
from comfy.cmd import folder_paths
from comfy.execution_context import current_execution_context
from comfy.utils import ProgressBar
from comfy.weight_adapter import adapters, adapter_maps
from comfy_api.latest import ComfyExtension, io, ui
from .nodes_custom_sampler import Noise_RandomNoise, Guider_Basic
logger = logging.getLogger(__name__)
class TrainGuider(Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
def outer_sample(
self,
noise,
latent_image,
sampler,
sigmas,
denoise_mask=None,
callback=None,
disable_pbar=False,
seed=None,
latent_shapes=None,
):
self.inner_model, self.conds, self.loaded_models = (
comfy.sampler_helpers.prepare_sampling(
self.model_patcher,
noise.shape,
self.conds,
self.model_options,
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
)
)
device = self.model_patcher.load_device
if denoise_mask is not None:
denoise_mask = comfy.sampler_helpers.prepare_mask(
denoise_mask, noise.shape, device
)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
comfy.samplers.cast_to_load_options(
self.model_options, device=device, dtype=self.model_patcher.model_dtype()
)
try:
self.model_patcher.pre_run()
output = self.inner_sample(
noise,
latent_image,
device,
sampler,
sigmas,
denoise_mask,
callback,
disable_pbar,
seed,
latent_shapes=latent_shapes,
)
finally:
self.model_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.loaded_models
return output
def make_batch_extra_option_dict(d, indicies, full_size=None):
new_dict = {}
for k, v in d.items():
newv = v
if isinstance(v, dict):
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
elif isinstance(v, torch.Tensor):
if full_size is None or v.size(0) == full_size:
newv = v[indicies]
elif isinstance(v, (list, tuple)) and len(v) == full_size:
newv = [v[i] for i in indicies]
new_dict[k] = newv
return new_dict
def process_cond_list(d, prefix=""):
if hasattr(d, "__iter__") and not hasattr(d, "items"):
for index, item in enumerate(d):
process_cond_list(item, f"{prefix}.{index}")
return d
elif hasattr(d, "items"):
for k, v in list(d.items()):
if isinstance(v, dict):
process_cond_list(v, f"{prefix}.{k}")
elif isinstance(v, torch.Tensor):
d[k] = v.clone()
elif isinstance(v, (list, tuple)):
for index, item in enumerate(v):
process_cond_list(item, f"{prefix}.{k}.{index}")
return d
class TrainSampler(comfy.samplers.Sampler):
def __init__(
self,
loss_fn,
optimizer,
loss_callback=None,
batch_size=1,
grad_acc=1,
total_steps=1,
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.grad_acc = grad_acc
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def _init_bucket_data(self, bucket_latents):
"""Initialize bucket offsets and weights for sampling."""
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
def fwd_bwd(
self,
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
):
xt = model_wrap.inner_model.model_sampling.noise_scaling(
batch_sigmas, batch_noise, batch_latent, False
)
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
torch.zeros_like(batch_sigmas),
torch.zeros_like(batch_noise),
batch_latent,
False,
)
model_wrap.conds["positive"] = [cond[i] for i in indicies]
batch_extra_args = make_batch_extra_option_dict(
extra_args, indicies, full_size=dataset_size
)
with torch.autocast(xt.device.type, dtype=self.training_dtype):
x0_pred = model_wrap(
xt.requires_grad_(True),
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred, x0)
if bwd:
bwd_loss = loss / self.grad_acc
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
"""Generate random sigma values for a batch."""
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(batch_size)
]
return torch.tensor(batch_sigmas).to(device)
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
"""Execute one training step in bucket mode."""
# Sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in standard (non-bucket, non-multi-res) mode."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
batch_latent = torch.stack([latent_image[i] for i in indicies])
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond,
indicies,
extra_args,
dataset_size,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
"""Execute one training step in multi-resolution mode (real_dataset is set)."""
indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
# todo: should this be "0" or scalar_tensor?
total_loss = torch.tensor(0.0)
for index in indicies:
single_latent = self.real_dataset[index].to(latent_image)
batch_noise = noisegen.generate_noise(
{"samples": single_latent}
).to(single_latent.device)
batch_sigmas = (
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
)
batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
single_latent,
cond,
[index],
extra_args,
dataset_size,
bwd=False,
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
def sample(
self,
model_wrap,
sigmas,
extra_args,
callback,
noise,
latent_image=None,
denoise_mask=None,
disable_pbar=False,
):
model_wrap.conds = process_cond_list(model_wrap.conds)
cond = model_wrap.conds["positive"]
dataset_size = sigmas.size(0)
torch.cuda.empty_cache()
ui_pbar = ProgressBar(self.total_steps)
for i in (
pbar := tqdm.trange(
self.total_steps,
desc="Training LoRA",
smoothing=0.01,
disable=not current_execution_context().server.receive_all_progress_notifications,
)
):
noisegen = Noise_RandomNoise(
self.seed + i * 1000
)
if self.bucket_latents is not None:
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
elif self.real_dataset is None:
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
else:
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
def draw_loss_graph(loss_map, steps):
width, height = 500, 300
img = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()]
prev_point = (0, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = int(i / (steps - 1) * width)
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
return img
def find_all_highest_child_module_with_forward(
model: torch.nn.Module, result=None, name=None
):
if result is None:
result = []
elif hasattr(model, "forward") and not isinstance(
model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
):
result.append(model)
logger.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
return result
name = name or "root"
for next_name, child in model.named_children():
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
return result
def patch(m):
if not hasattr(m, "forward"):
return
org_forward = m.forward
def fwd(args, kwargs):
return org_forward(*args, **kwargs)
def checkpointing_fwd(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training.
Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
Returns:
list of latent tensors
"""
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
def _load_existing_lora(existing_lora):
"""Load existing LoRA weights if provided.
Args:
existing_lora: LoRA filename or "[None]"
Returns:
tuple: (existing_weights dict, existing_steps int)
"""
if existing_lora == "[None]":
return {}, 0
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
existing_weights = {}
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
return existing_weights, existing_steps
def _create_weight_adapter(
module, module_name, existing_weights, algorithm, lora_dtype, rank
):
"""Create a weight adapter for a module with weight.
Args:
module: The module to create adapter for
module_name: Name of the module
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (train_adapter, lora_params dict)
"""
key = f"{module_name}.weight"
shape = module.weight.shape
lora_params = {}
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
# Try to load existing adapter
existing_adapter = None
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
module_name, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
if existing_adapter is None:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
module.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_params[f"{module_name}.{name}"] = parameter
return train_adapter.train().requires_grad_(True), lora_params
else:
# 1D weight - use BiasDiff
diff = torch.nn.Parameter(
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
)
diff_module = BiasDiff(diff).train().requires_grad_(True)
lora_params[f"{module_name}.diff"] = diff
return diff_module, lora_params
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias).train().requires_grad_(True)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
Args:
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
parameters: Parameters to optimize
learning_rate: Learning rate
Returns:
Optimizer instance
"""
if optimizer_name == "Adam":
return torch.optim.Adam(parameters, lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(parameters, lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(parameters, lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(parameters, lr=learning_rate)
def _create_loss_function(loss_function_name):
"""Create loss function based on name.
Args:
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
Returns:
Loss function instance
"""
if loss_function_name == "MSE":
return torch.nn.MSELoss()
elif loss_function_name == "L1":
return torch.nn.L1Loss()
elif loss_function_name == "Huber":
return torch.nn.HuberLoss()
elif loss_function_name == "SmoothL1":
return torch.nn.SmoothL1Loss()
def _run_training_loop(
guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
sigmas = torch.tensor(range(num_images))
noise = Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TrainLoraNode",
display_name="Train LoRA",
category="training",
is_experimental=True,
is_input_list=True, # All inputs become lists
inputs=[
io.Model.Input("model", tooltip="The model to train the LoRA on."),
io.Latent.Input(
"latents",
tooltip="The Latents to use for training, serve as dataset/input of the model.",
),
io.Conditioning.Input(
"positive", tooltip="The positive conditioning to use for training."
),
io.Int.Input(
"batch_size",
default=1,
min=1,
max=10000,
tooltip="The batch size to use for training.",
),
io.Int.Input(
"grad_accumulation_steps",
default=1,
min=1,
max=1024,
tooltip="The number of gradient accumulation steps to use for training.",
),
io.Int.Input(
"steps",
default=16,
min=1,
max=100000,
tooltip="The number of steps to train the LoRA for.",
),
io.Float.Input(
"learning_rate",
default=0.0005,
min=0.0000001,
max=1.0,
step=0.0000001,
tooltip="The learning rate to use for training.",
),
io.Int.Input(
"rank",
default=8,
min=1,
max=128,
tooltip="The rank of the LoRA layers.",
),
io.Combo.Input(
"optimizer",
options=["AdamW", "Adam", "SGD", "RMSprop"],
default="AdamW",
tooltip="The optimizer to use for training.",
),
io.Combo.Input(
"loss_function",
options=["MSE", "L1", "Huber", "SmoothL1"],
default="MSE",
tooltip="The loss function to use for training.",
),
io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
),
io.Combo.Input(
"training_dtype",
options=["bf16", "fp32"],
default="bf16",
tooltip="The dtype to use for training.",
),
io.Combo.Input(
"lora_dtype",
options=["bf16", "fp32"],
default="bf16",
tooltip="The dtype to use for lora.",
),
io.Combo.Input(
"algorithm",
options=list(adapter_maps.keys()),
default=list(adapter_maps.keys())[0],
tooltip="The algorithm to use for training.",
),
io.Boolean.Input(
"gradient_checkpointing",
default=True,
tooltip="Use gradient checkpointing for training.",
),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
],
outputs=[
io.Model.Output(
display_name="model", tooltip="Model with LoRA applied"
),
io.Custom("LORA_MODEL").Output(
display_name="lora", tooltip="LoRA weights"
),
io.Custom("LOSS_MAP").Output(
display_name="loss_map", tooltip="Loss history"
),
io.Int.Output(display_name="steps", tooltip="Total training steps"),
],
)
@classmethod
def execute(
cls,
model,
latents,
positive,
batch_size,
steps,
grad_accumulation_steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
batch_size = batch_size[0]
steps = steps[0]
grad_accumulation_steps = grad_accumulation_steps[0]
learning_rate = learning_rate[0]
rank = rank[0]
optimizer_name = optimizer[0]
loss_function_name = loss_function[0]
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
else:
latents = _process_latents_standard_mode(latents)
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
# Setup models for training
mp.model.requires_grad_(False)
# Load existing LoRA weights if provided
existing_weights, existing_steps = _load_existing_lora(existing_lora)
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
# Create optimizer and loss function
optimizer = _create_optimizer(
optimizer_name, lora_sd.values(), learning_rate
)
criterion = _create_loss_function(loss_function_name)
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=int(1e20), force_full_load=True
)
torch.cuda.empty_cache()
# Setup loss tracking
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
# Create sampler
if bucket_mode:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
# Setup guider
guider = TrainGuider(mp)
guider.set_conds(positive)
# Run training loop
try:
_run_training_loop(
guider,
train_sampler,
latents,
num_images,
seed,
bucket_mode,
multi_res,
)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode): #
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraModelLoader",
display_name="Load LoRA Model",
category="loaders",
is_experimental=True,
inputs=[
io.Model.Input(
"model", tooltip="The diffusion model the LoRA will be applied to."
),
io.Custom("LORA_MODEL").Input(
"lora", tooltip="The LoRA model to apply to the diffusion model."
),
io.Float.Input(
"strength_model",
default=1.0,
min=-100.0,
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
],
outputs=[
io.Model.Output(
display_name="model", tooltip="The modified diffusion model."
),
],
)
@classmethod
def execute(cls, model, lora, strength_model):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)
class SaveLoRA(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveLoRA",
display_name="Save LoRA Weights",
category="loaders",
is_experimental=True,
is_output_node=True,
inputs=[
io.Custom("LORA_MODEL").Input(
"lora",
tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
),
io.String.Input(
"prefix",
default="loras/ComfyUI_trained_lora",
tooltip="The prefix to use for the saved LoRA file.",
),
io.Int.Input(
"steps",
optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
),
],
outputs=[],
)
@classmethod
def execute(cls, lora, prefix, steps=None):
output_dir = folder_paths.get_output_directory()
full_output_folder, filename, counter, subfolder, filename_prefix = (
folder_paths.get_save_image_path(prefix, output_dir)
)
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return io.NodeOutput()
class LossGraphNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LossGraphNode",
display_name="Plot Loss Graph",
category="training",
is_experimental=True,
is_output_node=True,
inputs=[
io.Custom("LOSS_MAP").Input(
"loss", tooltip="Loss map from training node."
),
io.String.Input(
"filename_prefix",
default="loss_graph",
tooltip="Prefix for the saved loss graph image.",
),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
)
@classmethod
def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
min_loss, max_loss = min(loss_values), max(loss_values)
scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
font = None
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
# Convert PIL image to tensor for PreviewImage
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3]
# Return preview UI
return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))
# ========== Extension Setup ==========
class TrainingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TrainLoraNode,
LoraModelLoader,
SaveLoRA,
LossGraphNode,
]
async def comfy_entrypoint() -> TrainingExtension:
return TrainingExtension()