# ComfyUI/comfy_extras/nodes_train.py

import logging
import os
import numpy as np
import safetensors.torch  # needed for safetensors.torch.save_file below
import torch
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override
import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy_api.latest import ComfyExtension, io, ui
from comfy.utils import ProgressBar
def make_batch_extra_option_dict(d, indices, full_size=None):
    """Recursively slice batch-sized tensors/lists in an extra-args dict down to the given indices."""
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indices, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indices]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indices]
        new_dict[k] = newv
    return new_dict
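# Illustrative sketch of the slicing behavior above (hypothetical values, kept as a
# comment so it never runs at import time): tensors and lists whose leading dimension
# matches full_size are indexed down to the batch, everything else passes through.
#
#   extra = {"flags": "keep", "sigmas": torch.arange(8), "nested": {"c": torch.zeros(8, 4)}}
#   batch = make_batch_extra_option_dict(extra, [1, 4], full_size=8)
#   # batch["flags"] == "keep"
#   # batch["sigmas"].tolist() == [1, 4]
#   # batch["nested"]["c"].shape == (2, 4)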
def process_cond_list(d, prefix=""):
    """Recursively clone tensors stored inside conditioning lists/dicts."""
    if isinstance(d, (torch.Tensor, str, bytes)):
        # Tensors are cloned by the parent dict branch; recursing into a string
        # character-by-character would never terminate.
        return d
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
    return d
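# Why the clone: conds are typically built while ComfyUI runs under torch.inference_mode(),
# and inference tensors cannot participate in autograd. Calling this inside the
# torch.inference_mode(False) block in TrainLoraNode.execute replaces them with normal
# tensors. Sketch of the effect (hypothetical dict, comment only):
#
#   conds = {"positive": [{"cross_attn": some_inference_tensor}]}
#   conds = process_cond_list(conds)  # "cross_attn" is now a regular, grad-capable tensor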
class TrainSampler(comfy.samplers.Sampler):
def __init__(
self,
loss_fn,
optimizer,
loss_callback=None,
batch_size=1,
grad_acc=1,
total_steps=1,
seed=0,
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.grad_acc = grad_acc
self.seed = seed
self.training_dtype = training_dtype
self.real_dataset: list[torch.Tensor] | None = real_dataset
# Bucket mode data
self.bucket_latents: list[torch.Tensor] | None = bucket_latents # list of (Bi, C, Hi, Wi)
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
else:
self.bucket_offsets = None
self.bucket_weights = None
self.num_images = None
def _init_bucket_data(self, bucket_latents):
"""Initialize bucket offsets and weights for sampling."""
self.bucket_offsets = [0]
bucket_sizes = []
for lat in bucket_latents:
bucket_sizes.append(lat.shape[0])
self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
self.num_images = self.bucket_offsets[-1]
# Weights for sampling buckets proportional to their size
self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
    def fwd_bwd(
        self,
        model_wrap,
        batch_sigmas,
        batch_noise,
        batch_latent,
        cond,
        indices,
        extra_args,
        dataset_size,
        bwd=True,
    ):
        """Run one forward pass (and optionally backward) on a mini-batch; returns the unscaled loss."""
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False,
        )

        model_wrap.conds["positive"] = [cond[i] for i in indices]
        batch_extra_args = make_batch_extra_option_dict(
            extra_args, indices, full_size=dataset_size
        )

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(
                xt.requires_grad_(True),
                batch_sigmas.requires_grad_(True),
                **batch_extra_args,
            )
            loss = self.loss_fn(x0_pred, x0)
        if bwd:
            # Scale by grad_acc so accumulated gradients average over the accumulation window.
            bwd_loss = loss / self.grad_acc
            bwd_loss.backward()
        return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
"""Generate random sigma values for a batch."""
batch_sigmas = [
model_wrap.inner_model.model_sampling.percent_to_sigma(
torch.rand((1,)).item()
)
for _ in range(batch_size)
]
return torch.tensor(batch_sigmas).to(device)
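    # Timestep sampling note: percent_to_sigma maps a uniform draw in [0, 1) onto the
    # model's sigma schedule, so each batch element trains at an independently sampled
    # noise level. Sketch (comment only, hypothetical model_wrap):
    #
    #   sigmas = self._generate_batch_sigmas(model_wrap, 4, torch.device("cpu"))
    #   # sigmas.shape == (4,); each entry is a sigma drawn uniformly in schedule-percent space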
def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
"""Execute one training step in bucket mode."""
# Sample bucket (weighted by size), then sample batch from bucket
bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi)
bucket_size = bucket_latent.shape[0]
bucket_offset = self.bucket_offsets[bucket_idx]
# Sample indices from this bucket (use all if bucket_size < batch_size)
actual_batch_size = min(self.batch_size, bucket_size)
relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
# Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
absolute_indices = [bucket_offset + idx for idx in relative_indices]
batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W)
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
batch_latent.device
)
batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)
loss = self.fwd_bwd(
model_wrap,
batch_sigmas,
batch_noise,
batch_latent,
cond, # Use flattened cond with absolute indices
absolute_indices,
extra_args,
self.num_images,
bwd=True,
)
if self.loss_callback:
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})
    def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
        """Execute one training step in standard (non-bucket, single-resolution) mode."""
        indices = torch.randperm(dataset_size)[: self.batch_size].tolist()
        batch_latent = torch.stack([latent_image[i] for i in indices])
        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
            batch_latent.device
        )
        batch_sigmas = self._generate_batch_sigmas(
            model_wrap, min(self.batch_size, dataset_size), batch_latent.device
        )
        loss = self.fwd_bwd(
            model_wrap,
            batch_sigmas,
            batch_noise,
            batch_latent,
            cond,
            indices,
            extra_args,
            dataset_size,
            bwd=True,
        )
        if self.loss_callback:
            self.loss_callback(loss.item())
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
        """Execute one training step in multi-resolution mode (real_dataset is set)."""
        indices = torch.randperm(dataset_size)[: self.batch_size].tolist()
        total_loss = 0
        # Latents have mixed resolutions, so run them one at a time and average the loss.
        for index in indices:
            single_latent = self.real_dataset[index].to(latent_image)
            batch_noise = noisegen.generate_noise(
                {"samples": single_latent}
            ).to(single_latent.device)
            batch_sigmas = self._generate_batch_sigmas(
                model_wrap, 1, single_latent.device
            )
            loss = self.fwd_bwd(
                model_wrap,
                batch_sigmas,
                batch_noise,
                single_latent,
                cond,
                [index],
                extra_args,
                dataset_size,
                bwd=False,
            )
            total_loss += loss
        total_loss = total_loss / self.grad_acc / len(indices)
        total_loss.backward()
        if self.loss_callback:
            self.loss_callback(total_loss.item())
        pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
def sample(
self,
model_wrap,
sigmas,
extra_args,
callback,
noise,
latent_image=None,
denoise_mask=None,
disable_pbar=False,
):
        # Clone cond tensors out of inference mode so they can be used in autograd.
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        # "sigmas" carries no noise levels here; its length encodes the dataset size
        # (see _run_training_loop, which passes torch.arange(num_images)).
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
ui_pbar = ProgressBar(self.total_steps)
for i in (
pbar := trange(
self.total_steps,
desc="Training LoRA",
smoothing=0.01,
disable=not comfy.utils.PROGRESS_BAR_ENABLED,
)
):
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
self.seed + i * 1000
)
if self.bucket_latents is not None:
self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
elif self.real_dataset is None:
self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
else:
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
class BiasDiff(torch.nn.Module):
    """Trainable additive offset applied on top of an existing weight or bias tensor."""

    def __init__(self, bias):
        super().__init__()
        self.bias = bias
def __call__(self, b):
org_dtype = b.dtype
return (b.to(self.bias) + self.bias).to(org_dtype)
def passive_memory_usage(self):
return self.bias.nelement() * self.bias.element_size()
def move_to(self, device):
self.to(device=device)
return self.passive_memory_usage()
def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    loss_range = max(max_loss - min_loss, 1e-8)  # avoid division by zero on a flat loss curve
    scaled_loss = [(l - min_loss) / loss_range for l in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l in enumerate(scaled_loss[1:], start=1):
        x = int(i / max(steps - 1, 1) * width)  # guard against steps == 1
        y = height - int(l * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img
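# Usage sketch (comment only): draw_loss_graph expects a mapping from step to loss value
# plus the total step count, e.g.
#
#   img = draw_loss_graph({0: 1.0, 1: 0.5, 2: 0.25}, steps=3)
#   img.save("loss.png")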
def find_all_highest_child_module_with_forward(
    model: torch.nn.Module, result=None, name=None
):
    """Collect the highest-level child modules that define forward (containers are descended into).

    Note: the root call (result is None) intentionally skips the model itself and only
    inspects its children.
    """
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(
        model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    ):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result
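# Illustrative sketch (comment only, toy model): containers are descended into and the
# first non-container module on each branch is collected, so checkpointing wraps whole
# blocks rather than individual leaf layers.
#
#   toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sequential(torch.nn.ReLU()))
#   mods = find_all_highest_child_module_with_forward(toy)
#   # mods == [Linear(in_features=4, out_features=4), ReLU()]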
def patch(m):
    """Wrap a module's forward in torch.utils.checkpoint to trade compute for activation memory."""
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

    m.org_forward = org_forward
    m.forward = checkpointing_fwd
def unpatch(m):
if hasattr(m, "org_forward"):
m.forward = m.org_forward
del m.org_forward
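# Round-trip sketch (comment only): patch swaps in a checkpointed forward and stashes the
# original; unpatch restores it, so training and inference share the same module object.
#
#   layer = torch.nn.Linear(8, 8)
#   patch(layer)    # layer.forward now recomputes activations during backward
#   out = layer(torch.randn(1, 8).requires_grad_(True))
#   unpatch(layer)  # original forward restored, org_forward attribute removed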
def _process_latents_bucket_mode(latents):
    """Process latents for bucket mode training.

    Args:
        latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)

    Returns:
        list of latent tensors
    """
    return [latent_dict["samples"] for latent_dict in latents]
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
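# Shape sketch (comment only): with multiple latent dicts, batched entries are split into
# per-image tensors so mixed batch sizes can be handled uniformly downstream.
#
#   latents = [{"samples": torch.zeros(2, 4, 8, 8)}, {"samples": torch.zeros(1, 4, 8, 8)}]
#   out = _process_latents_standard_mode(latents)
#   # out is a list of three (1, 4, 8, 8) tensors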
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
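# Behavior sketch (comment only): uniform shapes are concatenated into one batch tensor,
# while mixed shapes switch on the multi-resolution path instead.
#
#   same, n, mr = _prepare_latents_and_count(
#       [torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 8, 8)], torch.float32, bucket_mode=False)
#   # same.shape == (2, 4, 8, 8); n == 2; mr is False
#
#   mixed, n, mr = _prepare_latents_and_count(
#       [torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 16, 16)], torch.float32, bucket_mode=False)
#   # mixed stays a list of 2 tensors; n == 2; mr is True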
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
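# Expansion rule (comment only): a single caption is repeated once per image; any other
# count mismatch is an error.
#
#   _validate_and_expand_conditioning([cond], num_images=4, bucket_mode=False)    # -> [cond] * 4
#   _validate_and_expand_conditioning([c1, c2], num_images=3, bucket_mode=False)  # -> ValueError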
def _load_existing_lora(existing_lora):
    """Load existing LoRA weights if provided.

    Args:
        existing_lora: LoRA filename or "[None]"

    Returns:
        tuple: (existing_weights dict, existing_steps int)
    """
    if existing_lora == "[None]":
        return {}, 0

    lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)

    # Extract the step count from filenames like "trained_lora_10_steps_20250225_203716";
    # fall back to 0 if the name does not follow that convention.
    try:
        existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
    except ValueError:
        existing_steps = 0

    existing_weights = {}
    if lora_path:
        existing_weights = comfy.utils.load_torch_file(lora_path)
    return existing_weights, existing_steps
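# Filename convention sketch (comment only): the step count is the token right before
# "_steps_", matching the names produced by SaveLoRA below.
#
#   _load_existing_lora("trained_lora_10_steps_20250225_203716.safetensors")
#   # -> (weights loaded from disk, 10)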
def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank):
    """Create a weight adapter for a module with a weight.

    Args:
        module: The module to create the adapter for
        module_name: Name of the module
        existing_weights: Dict of existing LoRA weights
        algorithm: Algorithm name for new adapters
        lora_dtype: dtype for LoRA weights
        rank: Rank for new LoRA adapters

    Returns:
        tuple: (train_adapter, lora_params dict)
    """
    key = f"{module_name}.weight"
    shape = module.weight.shape
    lora_params = {}

    if len(shape) >= 2:
        alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
        dora_scale = existing_weights.get(f"{key}.dora_scale", None)

        # Try to resume from an existing adapter of any known type
        existing_adapter = None
        for adapter_cls in adapters:
            existing_adapter = adapter_cls.load(
                module_name, existing_weights, alpha, dora_scale
            )
            if existing_adapter is not None:
                break

        if existing_adapter is not None:
            train_adapter = existing_adapter.to_train().to(lora_dtype)
        else:
            # No existing weights for this module: create a fresh adapter
            # (LoRA-style, alpha=1.0 by default)
            adapter_cls = adapter_maps[algorithm]
            train_adapter = adapter_cls.create_train(
                module.weight, rank=rank, alpha=1.0
            ).to(lora_dtype)

        for name, parameter in train_adapter.named_parameters():
            lora_params[f"{module_name}.{name}"] = parameter
        return train_adapter, lora_params
    else:
        # 1D weight (e.g. a norm scale) - train a full additive diff via BiasDiff
        diff = torch.nn.Parameter(
            torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
        )
        diff_module = BiasDiff(diff)
        lora_params[f"{module_name}.diff"] = diff
        return diff_module, lora_params
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
def _create_optimizer(optimizer_name, parameters, learning_rate):
    """Create an optimizer by name ("Adam", "AdamW", "SGD", "RMSprop").

    Args:
        optimizer_name: Name of the optimizer
        parameters: Parameters to optimize
        learning_rate: Learning rate

    Returns:
        Optimizer instance
    """
    optimizers = {
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
        "SGD": torch.optim.SGD,
        "RMSprop": torch.optim.RMSprop,
    }
    if optimizer_name not in optimizers:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")
    return optimizers[optimizer_name](parameters, lr=learning_rate)
def _create_loss_function(loss_function_name):
    """Create a loss function by name ("MSE", "L1", "Huber", "SmoothL1").

    Args:
        loss_function_name: Name of the loss function

    Returns:
        Loss function instance
    """
    losses = {
        "MSE": torch.nn.MSELoss,
        "L1": torch.nn.L1Loss,
        "Huber": torch.nn.HuberLoss,
        "SmoothL1": torch.nn.SmoothL1Loss,
    }
    if loss_function_name not in losses:
        raise ValueError(f"Unknown loss function: {loss_function_name}")
    return losses[loss_function_name]()
def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
    # Not real noise levels: TrainSampler.sample only reads the length of this tensor
    # as the dataset size.
    sigmas = torch.arange(num_images)
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
class TrainLoraNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TrainLoraNode",
display_name="Train LoRA",
category="training",
is_experimental=True,
is_input_list=True, # All inputs become lists
inputs=[
io.Model.Input("model", tooltip="The model to train the LoRA on."),
io.Latent.Input(
"latents",
tooltip="The Latents to use for training, serve as dataset/input of the model.",
),
io.Conditioning.Input(
"positive", tooltip="The positive conditioning to use for training."
),
io.Int.Input(
"batch_size",
default=1,
min=1,
max=10000,
tooltip="The batch size to use for training.",
),
io.Int.Input(
"grad_accumulation_steps",
default=1,
min=1,
max=1024,
tooltip="The number of gradient accumulation steps to use for training.",
),
io.Int.Input(
"steps",
default=16,
min=1,
max=100000,
tooltip="The number of steps to train the LoRA for.",
),
io.Float.Input(
"learning_rate",
default=0.0005,
min=0.0000001,
max=1.0,
step=0.0000001,
tooltip="The learning rate to use for training.",
),
io.Int.Input(
"rank",
default=8,
min=1,
max=128,
tooltip="The rank of the LoRA layers.",
),
io.Combo.Input(
"optimizer",
options=["AdamW", "Adam", "SGD", "RMSprop"],
default="AdamW",
tooltip="The optimizer to use for training.",
),
io.Combo.Input(
"loss_function",
options=["MSE", "L1", "Huber", "SmoothL1"],
default="MSE",
tooltip="The loss function to use for training.",
),
io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
),
io.Combo.Input(
"training_dtype",
options=["bf16", "fp32"],
default="bf16",
tooltip="The dtype to use for training.",
),
io.Combo.Input(
"lora_dtype",
options=["bf16", "fp32"],
default="bf16",
tooltip="The dtype to use for lora.",
),
io.Combo.Input(
"algorithm",
options=list(adapter_maps.keys()),
default=list(adapter_maps.keys())[0],
tooltip="The algorithm to use for training.",
),
io.Boolean.Input(
"gradient_checkpointing",
default=True,
tooltip="Use gradient checkpointing for training.",
),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
default="[None]",
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
),
io.Boolean.Input(
"bucket_mode",
default=False,
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
),
                io.Boolean.Input(
                    "offloading",
                    default=False,
                    tooltip="Allow model weights to be offloaded to free VRAM during training instead of forcing a full GPU load (may be slower).",
                ),
],
outputs=[
io.Model.Output(
display_name="model", tooltip="Model with LoRA applied"
),
io.Custom("LORA_MODEL").Output(
display_name="lora", tooltip="LoRA weights"
),
io.Custom("LOSS_MAP").Output(
display_name="loss_map", tooltip="Loss history"
),
io.Int.Output(display_name="steps", tooltip="Total training steps"),
],
)
@classmethod
def execute(
cls,
model,
latents,
positive,
batch_size,
steps,
grad_accumulation_steps,
learning_rate,
rank,
optimizer,
loss_function,
seed,
training_dtype,
lora_dtype,
algorithm,
gradient_checkpointing,
existing_lora,
bucket_mode,
offloading,
):
# Extract scalars from lists (due to is_input_list=True)
model = model[0]
batch_size = batch_size[0]
steps = steps[0]
grad_accumulation_steps = grad_accumulation_steps[0]
learning_rate = learning_rate[0]
rank = rank[0]
optimizer_name = optimizer[0]
loss_function_name = loss_function[0]
seed = seed[0]
training_dtype = training_dtype[0]
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
existing_lora = existing_lora[0]
        bucket_mode = bucket_mode[0]
        # is_input_list=True wraps every input in a list; offloading must be unwrapped
        # too, otherwise "not offloading" below would test a non-empty list.
        offloading = offloading[0]
# Process latents based on mode
if bucket_mode:
latents = _process_latents_bucket_mode(latents)
else:
latents = _process_latents_standard_mode(latents)
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
# Load existing LoRA weights if provided
existing_weights, existing_steps = _load_existing_lora(existing_lora)
# Setup LoRA adapters
lora_sd, all_weight_adapters = _setup_lora_adapters(
mp, existing_weights, algorithm, lora_dtype, rank
)
# Create optimizer and loss function
optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate)
criterion = _create_loss_function(loss_function_name)
# Setup gradient checkpointing
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(
mp.model.diffusion_model
):
patch(m)
# Setup models for training
mp.model.requires_grad_(False)
torch.cuda.empty_cache()
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
# Setup loss tracking
loss_map = {"loss": []}
def loss_callback(loss):
loss_map["loss"].append(loss)
# Create sampler
if bucket_mode:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
)
else:
train_sampler = TrainSampler(
criterion,
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
)
# Setup guider
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
guider.set_conds(positive)
# Run training loop
try:
_run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res)
finally:
for m in mp.model.modules():
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoraModelLoader",
display_name="Load LoRA Model",
category="loaders",
is_experimental=True,
inputs=[
io.Model.Input(
"model", tooltip="The diffusion model the LoRA will be applied to."
),
io.Custom("LORA_MODEL").Input(
"lora", tooltip="The LoRA model to apply to the diffusion model."
),
io.Float.Input(
"strength_model",
default=1.0,
min=-100.0,
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
],
outputs=[
io.Model.Output(
display_name="model", tooltip="The modified diffusion model."
),
],
)
@classmethod
def execute(cls, model, lora, strength_model):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)
class SaveLoRA(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveLoRA",
display_name="Save LoRA Weights",
category="loaders",
is_experimental=True,
is_output_node=True,
inputs=[
io.Custom("LORA_MODEL").Input(
"lora",
tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
),
io.String.Input(
"prefix",
default="loras/ComfyUI_trained_lora",
tooltip="The prefix to use for the saved LoRA file.",
),
                io.Int.Input(
                    "steps",
                    optional=True,
                    tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
                ),
],
outputs=[],
)
@classmethod
def execute(cls, lora, prefix, steps=None):
output_dir = folder_paths.get_output_directory()
full_output_folder, filename, counter, subfolder, filename_prefix = (
folder_paths.get_save_image_path(prefix, output_dir)
)
if steps is None:
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
else:
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
safetensors.torch.save_file(lora, output_checkpoint)
return io.NodeOutput()
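# Naming sketch (comment only): with the default prefix "loras/ComfyUI_trained_lora" and
# steps=10, the file lands in the output directory as e.g.
#
#   loras/ComfyUI_trained_lora_10_steps_00001_.safetensors
#
# which is the pattern _load_existing_lora parses to resume step counting.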
class LossGraphNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LossGraphNode",
display_name="Plot Loss Graph",
category="training",
is_experimental=True,
is_output_node=True,
inputs=[
io.Custom("LOSS_MAP").Input(
"loss", tooltip="Loss map from training node."
),
io.String.Input(
"filename_prefix",
default="loss_graph",
tooltip="Prefix for the saved loss graph image.",
),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
)
@classmethod
def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
loss_values = loss["loss"]
width, height = 800, 480
margin = 40
img = Image.new(
"RGB", (width + margin, height + margin), "white"
) # Extend canvas
draw = ImageDraw.Draw(img)
        min_loss, max_loss = min(loss_values), max(loss_values)
        loss_range = max(max_loss - min_loss, 1e-8)  # avoid division by zero on a flat loss curve
        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]
steps = len(loss_values)
prev_point = (margin, height - int(scaled_loss[0] * height))
for i, l in enumerate(scaled_loss[1:], start=1):
x = margin + int(i / steps * width) # Scale X properly
y = height - int(l * height)
draw.line([prev_point, (x, y)], fill="blue", width=2)
prev_point = (x, y)
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
draw.line(
[(margin, height), (width + margin, height)], fill="black", width=2
) # X-axis
font = None
try:
font = ImageFont.truetype("arial.ttf", 12)
except IOError:
font = ImageFont.load_default()
# Add axis labels
draw.text((5, height // 2), "Loss", font=font, fill="black")
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
# Add min/max loss values
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
draw.text(
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
)
# Convert PIL image to tensor for PreviewImage
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3]
# Return preview UI
return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))
# ========== Extension Setup ==========
class TrainingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TrainLoraNode,
LoraModelLoader,
SaveLoRA,
LossGraphNode,
]
async def comfy_entrypoint() -> TrainingExtension:
return TrainingExtension()