mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
When training_dtype is set to "none" and the model's native dtype is float16, GradScaler was unconditionally enabled. However, GradScaler does not support bfloat16 gradients (only float16/float32), causing a NotImplementedError when lora_dtype is "bf16" (the default). Fix by only enabling GradScaler when LoRA parameters are not in bfloat16, since bfloat16 has the same exponent range as float32 and does not need gradient scaling to avoid underflow. Fixes #13124
1493 lines
52 KiB
Python
1493 lines
52 KiB
Python
import logging
|
|
import os
|
|
|
|
import numpy as np
|
|
import safetensors
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.checkpoint
|
|
from tqdm.auto import trange
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
from typing_extensions import override
|
|
|
|
import comfy.samplers
|
|
import comfy.sampler_helpers
|
|
import comfy.sd
|
|
import comfy.utils
|
|
import comfy.model_management
|
|
from comfy.cli_args import args, PerformanceFeature
|
|
import comfy_extras.nodes_custom_sampler
|
|
import folder_paths
|
|
import node_helpers
|
|
from comfy.weight_adapter import adapters, adapter_maps
|
|
from comfy.weight_adapter.bypass import BypassInjectionManager
|
|
from comfy_api.latest import ComfyExtension, io, ui
|
|
from comfy.utils import ProgressBar
|
|
|
|
|
|
class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
    """
    CFGGuider with modifications for training specific logic

    Differences from the stock guider: model loading can be forced either
    fully on-device or offloaded (see ``offloading``), and cleanup of the
    patcher is guaranteed via try/finally even if sampling raises.
    """

    def __init__(self, *args, offloading=False, **kwargs):
        super().__init__(*args, **kwargs)
        # When True, prepare_sampling() is asked to offload weights instead
        # of force-loading the full model onto the compute device.
        self.offloading = offloading

    def outer_sample(
        self,
        noise,
        latent_image,
        sampler,
        sigmas,
        denoise_mask=None,
        callback=None,
        disable_pbar=False,
        seed=None,
        latent_shapes=None,
    ):
        """Prepare the model, run inner_sample, then tear everything down.

        Returns whatever ``inner_sample`` returns (for TrainSampler this is a
        dummy zeros latent).
        """
        # force_full_load / force_offload are mutually exclusive and driven
        # by the offloading flag chosen at construction time.
        self.inner_model, self.conds, self.loaded_models = (
            comfy.sampler_helpers.prepare_sampling(
                self.model_patcher,
                noise.shape,
                self.conds,
                self.model_options,
                force_full_load=not self.offloading,
                force_offload=self.offloading,
            )
        )
        torch.cuda.empty_cache()
        device = self.model_patcher.load_device

        if denoise_mask is not None:
            denoise_mask = comfy.sampler_helpers.prepare_mask(
                denoise_mask, noise.shape, device
            )

        # Move sampling inputs to the model's load device.
        noise = noise.to(device)
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)
        comfy.samplers.cast_to_load_options(
            self.model_options, device=device, dtype=self.model_patcher.model_dtype()
        )

        try:
            self.model_patcher.pre_run()
            output = self.inner_sample(
                noise,
                latent_image,
                device,
                sampler,
                sigmas,
                denoise_mask,
                callback,
                disable_pbar,
                seed,
                latent_shapes=latent_shapes,
            )
        finally:
            # Always undo pre_run patches, even if training/sampling failed.
            self.model_patcher.cleanup()

        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
        # Drop references so the model memory can actually be reclaimed.
        del self.inner_model
        del self.loaded_models
        return output
|
|
|
|
|
|
def make_batch_extra_option_dict(d, indicies, full_size=None):
    """Build a per-batch copy of an extra-options dict.

    Tensors whose first dimension matches ``full_size`` (or all tensors when
    ``full_size`` is None) are sliced down to ``indicies``; lists/tuples of
    length ``full_size`` are re-gathered by index; nested dicts are processed
    recursively; everything else is passed through unchanged.
    """
    sliced = {}
    for key, value in d.items():
        if isinstance(value, dict):
            sliced[key] = make_batch_extra_option_dict(value, indicies, full_size=full_size)
        elif isinstance(value, torch.Tensor) and (full_size is None or value.size(0) == full_size):
            sliced[key] = value[indicies]
        elif isinstance(value, (list, tuple)) and len(value) == full_size:
            sliced[key] = [value[i] for i in indicies]
        else:
            # Not batch-aligned: keep as-is.
            sliced[key] = value
    return sliced
|
|
|
|
|
|
def process_cond_list(d, prefix=""):
    """Walk a conditioning structure and clone every tensor held in a dict.

    Iterables (non-mapping) are traversed element by element; mappings have
    tensor values replaced with clones (detaching them from shared storage),
    while nested dicts and list/tuple values are traversed recursively.
    Returns ``d`` itself (modified in place where applicable).
    """
    is_mapping = hasattr(d, "items")
    if hasattr(d, "__iter__") and not is_mapping:
        for idx, element in enumerate(d):
            process_cond_list(element, f"{prefix}.{idx}")
    elif is_mapping:
        # list() so we can reassign values while iterating.
        for key, value in list(d.items()):
            if isinstance(value, dict):
                process_cond_list(value, f"{prefix}.{key}")
            elif isinstance(value, torch.Tensor):
                d[key] = value.clone()
            elif isinstance(value, (list, tuple)):
                for idx, element in enumerate(value):
                    process_cond_list(element, f"{prefix}.{key}.{idx}")
    return d
|
|
|
|
|
|
class TrainSampler(comfy.samplers.Sampler):
    """Sampler subclass whose "sampling" loop is actually a LoRA training loop.

    ``sample()`` performs ``total_steps`` forward/backward iterations over the
    provided latents + conditioning, stepping the optimizer every ``grad_acc``
    iterations, and returns a dummy zeros latent (no image is produced).

    Three data layouts are supported:
      * standard: one stacked latent tensor (``latent_image``)
      * multi-res: ``real_dataset`` is a list of per-image latents of
        differing shapes, processed one sample at a time
      * bucket mode: ``bucket_latents`` groups same-resolution latents into
        buckets sampled proportionally to their size
    """

    def __init__(
        self,
        loss_fn,
        optimizer,
        loss_callback=None,
        batch_size=1,
        grad_acc=1,
        total_steps=1,
        seed=0,
        training_dtype=torch.bfloat16,
        real_dataset=None,
        bucket_latents=None,
        use_grad_scaler=False,
    ):
        # Training objective (e.g. nn.MSELoss) between prediction and target.
        self.loss_fn = loss_fn
        # Optimizer over the LoRA parameters; stepped every `grad_acc` iters.
        self.optimizer = optimizer
        # Optional callable invoked with each step's float loss value.
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.grad_acc = grad_acc
        # Base seed; each step i uses seed + i * 1000 for its noise generator.
        self.seed = seed
        # Autocast dtype for the model forward pass.
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset
        # Bucket mode data
        self.bucket_latents: list[torch.Tensor] | None = (
            bucket_latents  # list of (Bi, C, Hi, Wi)
        )
        # GradScaler for fp16 training
        self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
        # Precompute bucket offsets and weights for sampling
        if bucket_latents is not None:
            self._init_bucket_data(bucket_latents)
        else:
            self.bucket_offsets = None
            self.bucket_weights = None
            self.num_images = None

    def _init_bucket_data(self, bucket_latents):
        """Initialize bucket offsets and weights for sampling."""
        # bucket_offsets[i] is the flattened index of bucket i's first image;
        # the last entry equals the total image count.
        self.bucket_offsets = [0]
        bucket_sizes = []
        for lat in bucket_latents:
            bucket_sizes.append(lat.shape[0])
            self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0])
        self.num_images = self.bucket_offsets[-1]
        # Weights for sampling buckets proportional to their size
        self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)

    def fwd_bwd(
        self,
        model_wrap,
        batch_sigmas,
        batch_noise,
        batch_latent,
        cond,
        indicies,
        extra_args,
        dataset_size,
        bwd=True,
    ):
        """Run one forward (and optionally backward) pass on a mini-batch.

        ``indicies`` select this batch's samples out of the flattened dataset
        of size ``dataset_size``. Returns the unscaled loss tensor; when
        ``bwd`` is True the loss is divided by ``grad_acc`` (gradient
        accumulation) and backpropagated, through the GradScaler if one is
        configured.
        """
        # xt: latent noised at batch_sigmas; x0: the clean target expressed in
        # the model's prediction space (noise_scaling with zero sigma/noise).
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False,
        )

        # Restrict the positive conditioning to this batch's samples.
        model_wrap.conds["positive"] = [cond[i] for i in indicies]
        batch_extra_args = make_batch_extra_option_dict(
            extra_args, indicies, full_size=dataset_size
        )

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(
                xt.requires_grad_(True),
                batch_sigmas.requires_grad_(True),
                **batch_extra_args,
            )
            # Compare in float32 regardless of the autocast dtype.
            loss = self.loss_fn(x0_pred.float(), x0.float())
        if bwd:
            bwd_loss = loss / self.grad_acc
            if self.grad_scaler is not None:
                self.grad_scaler.scale(bwd_loss).backward()
            else:
                bwd_loss.backward()
        return loss

    def _generate_batch_sigmas(self, model_wrap, batch_size, device):
        """Generate random sigma values for a batch."""
        # Each sample gets an independent uniformly-random noise level mapped
        # through the model's percent_to_sigma schedule.
        batch_sigmas = [
            model_wrap.inner_model.model_sampling.percent_to_sigma(
                torch.rand((1,)).item()
            )
            for _ in range(batch_size)
        ]
        return torch.tensor(batch_sigmas).to(device)

    def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
        """Execute one training step in bucket mode."""
        # Sample bucket (weighted by size), then sample batch from bucket
        bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
        bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
        bucket_size = bucket_latent.shape[0]
        bucket_offset = self.bucket_offsets[bucket_idx]

        # Sample indices from this bucket (use all if bucket_size < batch_size)
        actual_batch_size = min(self.batch_size, bucket_size)
        relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist()
        # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index)
        absolute_indices = [bucket_offset + idx for idx in relative_indices]

        # .to(latent_image) moves the batch to the dummy latent's device/dtype.
        batch_latent = bucket_latent[relative_indices].to(latent_image)  # (actual_batch_size, C, H, W)
        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
            batch_latent.device
        )
        batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)

        loss = self.fwd_bwd(
            model_wrap,
            batch_sigmas,
            batch_noise,
            batch_latent,
            cond,  # Use flattened cond with absolute indices
            absolute_indices,
            extra_args,
            self.num_images,
            bwd=True,
        )
        if self.loss_callback:
            self.loss_callback(loss.item())
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})

    def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
        """Execute one training step in standard (non-bucket, non-multi-res) mode."""
        # Draw a random batch (without replacement) from the stacked latents.
        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
        batch_latent = torch.stack([latent_image[i] for i in indicies])
        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
            batch_latent.device
        )
        batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)

        loss = self.fwd_bwd(
            model_wrap,
            batch_sigmas,
            batch_noise,
            batch_latent,
            cond,
            indicies,
            extra_args,
            dataset_size,
            bwd=True,
        )
        if self.loss_callback:
            self.loss_callback(loss.item())
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
        """Execute one training step in multi-resolution mode (real_dataset is set)."""
        # Latents have mixed shapes, so each sample is run individually and
        # the losses are averaged before a single backward pass.
        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
        total_loss = 0
        for index in indicies:
            single_latent = self.real_dataset[index].to(latent_image)
            batch_noise = noisegen.generate_noise(
                {"samples": single_latent}
            ).to(single_latent.device)
            batch_sigmas = (
                model_wrap.inner_model.model_sampling.percent_to_sigma(
                    torch.rand((1,)).item()
                )
            )
            batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
            loss = self.fwd_bwd(
                model_wrap,
                batch_sigmas,
                batch_noise,
                single_latent,
                cond,
                [index],
                extra_args,
                dataset_size,
                bwd=False,
            )
            total_loss += loss
        # Average over the batch and scale for gradient accumulation, then do
        # one backward pass over the summed graph.
        total_loss = total_loss / self.grad_acc / len(indicies)
        if self.grad_scaler is not None:
            self.grad_scaler.scale(total_loss).backward()
        else:
            total_loss.backward()
        if self.loss_callback:
            self.loss_callback(total_loss.item())
        pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})

    def sample(
        self,
        model_wrap,
        sigmas,
        extra_args,
        callback,
        noise,
        latent_image=None,
        denoise_mask=None,
        disable_pbar=False,
    ):
        """Run the full training loop; returns a zeros latent placeholder.

        NOTE: ``sigmas`` here encodes the dataset size (one entry per image),
        not actual noise levels — see _run_training_loop.
        """
        # Clone tensors in conds so training-time mutation is safe.
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        ui_pbar = ProgressBar(self.total_steps)

        for i in (
            pbar := trange(
                self.total_steps,
                desc="Training LoRA",
                smoothing=0.01,
                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
            )
        ):
            # Fresh, deterministic noise generator per step.
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                self.seed + i * 1000
            )

            if self.bucket_latents is not None:
                self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
            elif self.real_dataset is None:
                self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
            else:
                self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)

            # Optimizer step only on gradient-accumulation boundaries.
            if (i + 1) % self.grad_acc == 0:
                if self.grad_scaler is not None:
                    # Unscale before touching grads so the dtype cast below
                    # operates on real gradient values.
                    self.grad_scaler.unscale_(self.optimizer)
                # Cast grads to the parameter dtype (autocast may have
                # produced grads in a different dtype).
                for param_groups in self.optimizer.param_groups:
                    for param in param_groups["params"]:
                        if param.grad is None:
                            continue
                        param.grad.data = param.grad.data.to(param.data.dtype)
                if self.grad_scaler is not None:
                    self.grad_scaler.step(self.optimizer)
                    self.grad_scaler.update()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()
            ui_pbar.update(1)
        torch.cuda.empty_cache()
        # Training produces no image; return a placeholder of matching shape.
        return torch.zeros_like(latent_image)
|
|
|
|
|
|
class BiasDiff(torch.nn.Module):
    """Trainable additive diff applied on top of a frozen 1D weight or bias.

    Used as a weight wrapper: calling the instance adds the stored diff to
    the incoming tensor, computing in the diff's dtype/device and returning
    in the input's original dtype.
    """

    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        original_dtype = b.dtype
        shifted = b.to(self.bias) + self.bias
        return shifted.to(original_dtype)

    def passive_memory_usage(self):
        # Bytes occupied by the stored diff tensor.
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        # Move the module and report how much memory now lives on `device`.
        self.to(device=device)
        return self.passive_memory_usage()
|
|
|
|
|
|
def draw_loss_graph(loss_map, steps):
    """Render the recorded loss curve to a 500x300 PIL image.

    Args:
        loss_map: mapping whose values are the recorded loss floats, in
            insertion (step) order.
        steps: total number of steps, used to scale the x axis.

    Returns:
        A PIL.Image.Image with the loss polyline drawn in blue. Degenerate
        inputs (no losses, a single point) return a blank/partial image
        instead of crashing.
    """
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    values = list(loss_map.values())
    if not values:
        return img  # nothing recorded yet

    min_loss, max_loss = min(values), max(values)
    # Guard: a perfectly flat curve would otherwise divide by zero.
    loss_range = max_loss - min_loss
    if loss_range == 0:
        loss_range = 1.0
    scaled_loss = [(l - min_loss) / loss_range for l in values]

    # Guard: steps == 1 would otherwise divide by zero on the x axis.
    x_denom = max(steps - 1, 1)
    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l in enumerate(scaled_loss[1:], start=1):
        x = int(i / x_denom * width)
        y = height - int(l * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img
|
|
|
|
|
|
def find_all_highest_child_module_with_forward(
    model: torch.nn.Module, result=None, name=None
):
    """Collect the shallowest descendant modules that define a usable forward.

    The root call (``result is None``) never collects the root itself: it
    only initializes the accumulator and descends. Recursion stops at the
    first non-container module with a ``forward`` on each branch, so only
    the highest such children are returned.
    """
    container_types = (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, container_types):
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        result.append(model)
        return result
    label = name or "root"
    for child_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{label}.{child_name}")
    return result
|
|
|
|
|
|
def find_modules_at_depth(
    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
) -> list[nn.Module]:
    """
    Find modules at a specific depth level for gradient checkpointing.

    Only non-container modules with a ``forward`` count toward the depth;
    ModuleList/Sequential/ModuleDict are transparent. Recursion stops once
    the target depth is reached on a branch.

    Args:
        model: The model to search
        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
        result: Accumulator for results
        current_depth: Current recursion depth
        name: Current module name for logging

    Returns:
        List of modules at the target depth
    """
    if result is None:
        result = []
    label = name or "root"

    # Containers don't have a meaningful forward; they don't add depth.
    container_types = (nn.ModuleList, nn.Sequential, nn.ModuleDict)
    if hasattr(model, "forward") and not isinstance(model, container_types):
        current_depth += 1
        if current_depth == depth:
            logging.debug(f"Found module at depth {depth}: {label} ({model.__class__.__name__})")
            result.append(model)
            return result

    for child_name, child in model.named_children():
        find_modules_at_depth(child, depth, result, current_depth, f"{label}.{child_name}")

    return result
|
|
|
|
|
|
class OffloadCheckpointFunction(torch.autograd.Function):
    """
    Gradient checkpointing that works with weight offloading.

    The forward pass runs under no_grad so no activation graph is retained
    and module weights may be freed afterwards; the backward pass recomputes
    the forward under enable_grad and differentiates the recomputation.

    For single input, single output modules (Linear, Conv*).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, forward_fn):
        # Stash the input tensor and the callable for recomputation later.
        ctx.save_for_backward(x)
        ctx.forward_fn = forward_fn
        with torch.no_grad():
            return forward_fn(x)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (saved_input,) = ctx.saved_tensors
        fn = ctx.forward_fn
        ctx.forward_fn = None  # drop the reference early

        with torch.enable_grad():
            # Recompute the forward on a detached leaf so grads flow into it.
            leaf = saved_input.detach().requires_grad_(True)
            recomputed = fn(leaf)
            recomputed.backward(grad_out)
            input_grad = leaf.grad

        # Explicit cleanup of the recomputation graph.
        del recomputed, leaf, fn

        return input_grad, None
|
|
|
|
|
|
def patch(m, offloading=False):
    """Swap ``m.forward`` for a gradient-checkpointed wrapper.

    Linear/Conv modules get the offload-compatible single-tensor checkpoint
    when ``offloading`` is enabled; everything else uses the standard
    non-reentrant torch checkpoint. The original forward is preserved on
    ``m.org_forward`` so ``unpatch`` can restore it.
    """
    if not hasattr(m, "forward"):
        return
    original_forward = m.forward

    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        # Single-input/single-output module: offload-compatible checkpoint.
        def checkpointing_fwd(x):
            return OffloadCheckpointFunction.apply(x, original_forward)
    else:
        # Generic module: standard activation checkpointing.
        def fwd(args, kwargs):
            return original_forward(*args, **kwargs)

        def checkpointing_fwd(*args, **kwargs):
            return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

    m.org_forward = original_forward
    m.forward = checkpointing_fwd
|
|
|
|
|
|
def unpatch(m):
    """Undo patch(): restore the saved forward, if one exists."""
    if not hasattr(m, "org_forward"):
        return  # never patched (or already unpatched) — nothing to do
    m.forward = m.org_forward
    del m.org_forward
|
|
|
|
|
|
def _process_latents_bucket_mode(latents):
|
|
"""Process latents for bucket mode training.
|
|
|
|
Args:
|
|
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
|
|
|
|
Returns:
|
|
list of latent tensors
|
|
"""
|
|
bucket_latents = []
|
|
for latent_dict in latents:
|
|
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
|
|
return bucket_latents
|
|
|
|
|
|
def _process_latents_standard_mode(latents):
|
|
"""Process latents for standard (non-bucket) mode training.
|
|
|
|
Args:
|
|
latents: list of latent dicts or single latent dict
|
|
|
|
Returns:
|
|
Processed latents (tensor or list of tensors)
|
|
"""
|
|
if len(latents) == 1:
|
|
return latents[0]["samples"] # Single latent dict
|
|
|
|
latent_list = []
|
|
for latent in latents:
|
|
latent = latent["samples"]
|
|
bs = latent.shape[0]
|
|
if bs != 1:
|
|
for sub_latent in latent:
|
|
latent_list.append(sub_latent[None])
|
|
else:
|
|
latent_list.append(latent)
|
|
return latent_list
|
|
|
|
|
|
def _process_conditioning(positive):
|
|
"""Process conditioning - either single list or list of lists.
|
|
|
|
Args:
|
|
positive: list of conditioning
|
|
|
|
Returns:
|
|
Flattened conditioning list
|
|
"""
|
|
if len(positive) == 1:
|
|
return positive[0] # Single conditioning list
|
|
|
|
# Multiple conditioning lists - flatten
|
|
flat_positive = []
|
|
for cond in positive:
|
|
if isinstance(cond, list):
|
|
flat_positive.extend(cond)
|
|
else:
|
|
flat_positive.append(cond)
|
|
return flat_positive
|
|
|
|
|
|
def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
|
"""Convert latents to dtype and compute image counts.
|
|
|
|
Args:
|
|
latents: Latents (tensor, list of tensors, or bucket list)
|
|
dtype: Target dtype
|
|
bucket_mode: Whether bucket mode is enabled
|
|
|
|
Returns:
|
|
tuple: (processed_latents, num_images, multi_res)
|
|
"""
|
|
if bucket_mode:
|
|
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
|
|
latents = [t.to(dtype) for t in latents]
|
|
num_buckets = len(latents)
|
|
num_images = sum(t.shape[0] for t in latents)
|
|
multi_res = False # Not using multi_res path in bucket mode
|
|
|
|
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
|
for i, lat in enumerate(latents):
|
|
logging.debug(f" Bucket {i}: shape {lat.shape}")
|
|
return latents, num_images, multi_res
|
|
|
|
# Non-bucket mode
|
|
if isinstance(latents, list):
|
|
all_shapes = set()
|
|
latents = [t.to(dtype) for t in latents]
|
|
for latent in latents:
|
|
all_shapes.add(latent.shape)
|
|
logging.debug(f"Latent shapes: {all_shapes}")
|
|
if len(all_shapes) > 1:
|
|
multi_res = True
|
|
else:
|
|
multi_res = False
|
|
latents = torch.cat(latents, dim=0)
|
|
num_images = len(latents)
|
|
elif isinstance(latents, torch.Tensor):
|
|
latents = latents.to(dtype)
|
|
num_images = latents.shape[0]
|
|
multi_res = False
|
|
else:
|
|
logging.error(f"Invalid latents type: {type(latents)}")
|
|
num_images = 0
|
|
multi_res = False
|
|
|
|
return latents, num_images, multi_res
|
|
|
|
|
|
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
|
|
"""Validate conditioning count matches image count, expand if needed.
|
|
|
|
Args:
|
|
positive: Conditioning list
|
|
num_images: Number of images
|
|
bucket_mode: Whether bucket mode is enabled
|
|
|
|
Returns:
|
|
Validated/expanded conditioning list
|
|
|
|
Raises:
|
|
ValueError: If conditioning count doesn't match image count
|
|
"""
|
|
if bucket_mode:
|
|
return positive # Skip validation in bucket mode
|
|
|
|
logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
|
if len(positive) == 1 and num_images > 1:
|
|
return positive * num_images
|
|
elif len(positive) != num_images:
|
|
raise ValueError(
|
|
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
|
)
|
|
return positive
|
|
|
|
|
|
def _load_existing_lora(existing_lora):
|
|
"""Load existing LoRA weights if provided.
|
|
|
|
Args:
|
|
existing_lora: LoRA filename or "[None]"
|
|
|
|
Returns:
|
|
tuple: (existing_weights dict, existing_steps int)
|
|
"""
|
|
if existing_lora == "[None]":
|
|
return {}, 0
|
|
|
|
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
|
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
|
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
|
existing_weights = {}
|
|
if lora_path:
|
|
existing_weights = comfy.utils.load_torch_file(lora_path)
|
|
return existing_weights, existing_steps
|
|
|
|
|
|
def _create_weight_adapter(
    module, module_name, existing_weights, algorithm, lora_dtype, rank
):
    """Create a weight adapter for a module with weight.

    For 2D+ weights, first tries to resume an adapter from
    ``existing_weights`` (iterating every registered adapter class); only if
    none matches is a fresh adapter of type ``adapter_maps[algorithm]``
    created. 1D weights (e.g. norm scales) get a simple additive BiasDiff
    instead.

    Args:
        module: The module to create adapter for
        module_name: Name of the module
        existing_weights: Dict of existing LoRA weights
        algorithm: Algorithm name for new adapters
        lora_dtype: dtype for LoRA weights
        rank: Rank for new LoRA adapters

    Returns:
        tuple: (train_adapter, lora_params dict) — lora_params maps
        state-dict key names to the trainable parameters.
    """
    key = f"{module_name}.weight"
    shape = module.weight.shape
    lora_params = {}

    logging.debug(f"Creating weight adapter for {key} with shape {shape}")

    if len(shape) >= 2:
        # alpha/dora_scale come from the saved weights when resuming.
        alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
        dora_scale = existing_weights.get(f"{key}.dora_scale", None)

        # Try to load existing adapter
        existing_adapter = None
        for adapter_cls in adapters:
            existing_adapter = adapter_cls.load(
                module_name, existing_weights, alpha, dora_scale
            )
            if existing_adapter is not None:
                break

        if existing_adapter is None:
            adapter_cls = adapter_maps[algorithm]

        if existing_adapter is not None:
            # Resume: convert the loaded adapter into trainable form.
            train_adapter = existing_adapter.to_train().to(lora_dtype)
        else:
            # Use LoRA with alpha=1.0 by default
            train_adapter = adapter_cls.create_train(
                module.weight, rank=rank, alpha=1.0
            ).to(lora_dtype)

        # Expose the adapter's parameters under fully-qualified names so they
        # can be optimized and later saved into the LoRA state dict.
        for name, parameter in train_adapter.named_parameters():
            lora_params[f"{module_name}.{name}"] = parameter

        return train_adapter.train().requires_grad_(True), lora_params
    else:
        # 1D weight - use BiasDiff
        diff = torch.nn.Parameter(
            torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
        )
        diff_module = BiasDiff(diff).train().requires_grad_(True)
        lora_params[f"{module_name}.diff"] = diff
        return diff_module, lora_params
|
|
|
|
|
|
def _create_bias_adapter(module, module_name, lora_dtype):
    """Create a bias adapter for a module with bias.

    Args:
        module: The module with bias
        module_name: Name of the module
        lora_dtype: dtype for LoRA weights

    Returns:
        tuple: (bias_module, lora_params dict)
    """
    # Trainable zero-initialized diff with the same shape as the bias.
    diff_b = torch.nn.Parameter(
        torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
    )
    adapter = BiasDiff(diff_b).train().requires_grad_(True)
    return adapter, {f"{module_name}.diff_b": diff_b}
|
|
|
|
|
|
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
    """Setup all LoRA adapters on the model.

    Walks every module of ``mp.model`` that exposes ``weight_function`` and
    registers a weight adapter (and a bias adapter when a bias exists) as
    weight wrappers on the patcher.

    Args:
        mp: Model patcher
        existing_weights: Dict of existing LoRA weights
        algorithm: Algorithm name for new adapters
        lora_dtype: dtype for LoRA weights
        rank: Rank for new LoRA adapters

    Returns:
        tuple: (lora_sd dict, all_weight_adapters list)
    """
    lora_sd = {}
    all_weight_adapters = []

    for module_name, module in mp.model.named_modules():
        # Only modules exposing `weight_function` participate in patching.
        if not hasattr(module, "weight_function"):
            continue

        if module.weight is not None:
            adapter, params = _create_weight_adapter(
                module, module_name, existing_weights, algorithm, lora_dtype, rank
            )
            lora_sd.update(params)
            mp.add_weight_wrapper(f"{module_name}.weight", adapter)
            all_weight_adapters.append(adapter)

        if getattr(module, "bias", None) is not None:
            bias_adapter, bias_params = _create_bias_adapter(module, module_name, lora_dtype)
            lora_sd.update(bias_params)
            mp.add_weight_wrapper(f"{module_name}.bias", bias_adapter)
            all_weight_adapters.append(bias_adapter)

    return lora_sd, all_weight_adapters
|
|
|
|
|
|
def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
    """Setup LoRA adapters in bypass mode.

    In bypass mode:
    - Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
    - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)

    This is useful when the base model weights are quantized and cannot be
    directly modified.

    Args:
        mp: Model patcher
        existing_weights: Dict of existing LoRA weights
        algorithm: Algorithm name for new adapters
        lora_dtype: dtype for LoRA weights
        rank: Rank for new LoRA adapters

    Returns:
        tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
    """
    lora_sd = {}
    all_weight_adapters = []
    bypass_manager = BypassInjectionManager()

    for module_name, module in mp.model.named_modules():
        if not hasattr(module, "weight_function"):
            continue

        if module.weight is not None:
            adapter, params = _create_weight_adapter(
                module, module_name, existing_weights, algorithm, lora_dtype, rank
            )
            lora_sd.update(params)
            all_weight_adapters.append(adapter)

            key = f"{module_name}.weight"
            # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
            # Only use bypass for adapters that have h() method (lora/lokr/oft)
            if isinstance(adapter, BiasDiff):
                mp.add_weight_wrapper(key, adapter)
                logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
            else:
                bypass_manager.add_adapter(key, adapter, strength=1.0)
                logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")

        if getattr(module, "bias", None) is not None:
            # Bias adapters still use weight wrapper (bias is usually not quantized)
            bias_adapter, bias_params = _create_bias_adapter(module, module_name, lora_dtype)
            lora_sd.update(bias_params)
            key = f"{module_name}.bias"
            mp.add_weight_wrapper(key, bias_adapter)
            all_weight_adapters.append(bias_adapter)
            logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")

    return lora_sd, all_weight_adapters, bypass_manager
|
|
|
|
|
|
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
|
"""Create optimizer based on name.
|
|
|
|
Args:
|
|
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
|
|
parameters: Parameters to optimize
|
|
learning_rate: Learning rate
|
|
|
|
Returns:
|
|
Optimizer instance
|
|
"""
|
|
if optimizer_name == "Adam":
|
|
return torch.optim.Adam(parameters, lr=learning_rate)
|
|
elif optimizer_name == "AdamW":
|
|
return torch.optim.AdamW(parameters, lr=learning_rate)
|
|
elif optimizer_name == "SGD":
|
|
return torch.optim.SGD(parameters, lr=learning_rate)
|
|
elif optimizer_name == "RMSprop":
|
|
return torch.optim.RMSprop(parameters, lr=learning_rate)
|
|
|
|
|
|
def _create_loss_function(loss_function_name):
|
|
"""Create loss function based on name.
|
|
|
|
Args:
|
|
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
|
|
|
|
Returns:
|
|
Loss function instance
|
|
"""
|
|
if loss_function_name == "MSE":
|
|
return torch.nn.MSELoss()
|
|
elif loss_function_name == "L1":
|
|
return torch.nn.L1Loss()
|
|
elif loss_function_name == "Huber":
|
|
return torch.nn.HuberLoss()
|
|
elif loss_function_name == "SmoothL1":
|
|
return torch.nn.SmoothL1Loss()
|
|
|
|
|
|
def _run_training_loop(
    guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res
):
    """Execute the training loop by invoking guider.sample once.

    The ``sigmas`` tensor passed here is not a noise schedule: its length
    encodes the dataset size for TrainSampler. In bucket and multi-res modes
    the latent handed to the guider is only a shape-compatible dummy (the
    first latent repeated ``num_images`` times) — the sampler pulls the real
    data from its own bucket/real_dataset state.

    Args:
        guider: The guider object
        train_sampler: The training sampler
        latents: Latent tensors
        num_images: Number of images
        seed: Random seed
        bucket_mode: Whether bucket mode is enabled
        multi_res: Whether multi-resolution mode is enabled
    """
    sigmas = torch.tensor(range(num_images))
    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)

    if bucket_mode:
        # Use first bucket's first latent as dummy for guider
        latent_input = latents[0][:1].repeat(num_images, 1, 1, 1)
    elif multi_res:
        # use first latent as dummy latent if multi_res
        latent_input = latents[0].repeat(num_images, 1, 1, 1)
    else:
        latent_input = latents

    guider.sample(
        noise.generate_noise({"samples": latent_input}),
        latent_input,
        train_sampler,
        sigmas,
        seed=noise.seed,
    )
|
|
|
|
|
|
class TrainLoraNode(io.ComfyNode):
    """Experimental node that trains a LoRA on a diffusion model.

    Takes latents (the dataset) plus positive conditioning, attaches trainable
    weight adapters to the model, runs the training loop via a specialized
    guider/sampler pair, and returns the trained LoRA state dict, the loss
    history, and the total step count.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode",
            display_name="Train LoRA",
            category="training",
            is_experimental=True,
            is_input_list=True,  # All inputs become lists
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input(
                    "latents",
                    tooltip="The Latents to use for training, serve as dataset/input of the model.",
                ),
                io.Conditioning.Input(
                    "positive", tooltip="The positive conditioning to use for training."
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=10000,
                    tooltip="The batch size to use for training.",
                ),
                io.Int.Input(
                    "grad_accumulation_steps",
                    default=1,
                    min=1,
                    max=1024,
                    tooltip="The number of gradient accumulation steps to use for training.",
                ),
                io.Int.Input(
                    "steps",
                    default=16,
                    min=1,
                    max=100000,
                    tooltip="The number of steps to train the LoRA for.",
                ),
                io.Float.Input(
                    "learning_rate",
                    default=0.0005,
                    min=0.0000001,
                    max=1.0,
                    step=0.0000001,
                    tooltip="The learning rate to use for training.",
                ),
                io.Int.Input(
                    "rank",
                    default=8,
                    min=1,
                    max=128,
                    tooltip="The rank of the LoRA layers.",
                ),
                io.Combo.Input(
                    "optimizer",
                    options=["AdamW", "Adam", "SGD", "RMSprop"],
                    default="AdamW",
                    tooltip="The optimizer to use for training.",
                ),
                io.Combo.Input(
                    "loss_function",
                    options=["MSE", "L1", "Huber", "SmoothL1"],
                    default="MSE",
                    tooltip="The loss function to use for training.",
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
                ),
                io.Combo.Input(
                    "training_dtype",
                    options=["bf16", "fp32", "none"],
                    default="bf16",
                    tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
                ),
                io.Combo.Input(
                    "lora_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for lora.",
                ),
                io.Boolean.Input(
                    "quantized_backward",
                    default=False,
                    tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
                ),
                io.Combo.Input(
                    "algorithm",
                    options=list(adapter_maps.keys()),
                    default=list(adapter_maps.keys())[0],
                    tooltip="The algorithm to use for training.",
                ),
                io.Boolean.Input(
                    "gradient_checkpointing",
                    default=True,
                    tooltip="Use gradient checkpointing for training.",
                ),
                io.Int.Input(
                    "checkpoint_depth",
                    default=1,
                    min=1,
                    max=5,
                    tooltip="Depth level for gradient checkpointing.",
                ),
                io.Boolean.Input(
                    "offloading",
                    default=False,
                    tooltip="Offload model weights to CPU during training to save GPU memory.",
                ),
                io.Combo.Input(
                    "existing_lora",
                    options=folder_paths.get_filename_list("loras") + ["[None]"],
                    default="[None]",
                    tooltip="The existing LoRA to append to. Set to None for new LoRA.",
                ),
                io.Boolean.Input(
                    "bucket_mode",
                    default=False,
                    tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
                ),
                io.Boolean.Input(
                    "bypass_mode",
                    default=False,
                    tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
                ),
            ],
            outputs=[
                io.Custom("LORA_MODEL").Output(
                    display_name="lora", tooltip="LoRA weights"
                ),
                io.Custom("LOSS_MAP").Output(
                    display_name="loss_map", tooltip="Loss history"
                ),
                io.Int.Output(display_name="steps", tooltip="Total training steps"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        quantized_backward,
        algorithm,
        gradient_checkpointing,
        checkpoint_depth,
        offloading,
        existing_lora,
        bucket_mode,
        bypass_mode,
    ):
        """Run LoRA training and return (lora_sd, loss_map, total_steps).

        All parameters arrive as lists because the schema sets
        ``is_input_list=True``; scalars are taken from index 0 below, while
        ``latents`` and ``positive`` keep their list form as the dataset.
        """
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
        batch_size = batch_size[0]
        steps = steps[0]
        grad_accumulation_steps = grad_accumulation_steps[0]
        learning_rate = learning_rate[0]
        rank = rank[0]
        optimizer_name = optimizer[0]
        loss_function_name = loss_function[0]
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
        quantized_backward = quantized_backward[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
        offloading = offloading[0]
        checkpoint_depth = checkpoint_depth[0]
        existing_lora = existing_lora[0]
        bucket_mode = bucket_mode[0]
        bypass_mode = bypass_mode[0]

        # Global flag read elsewhere to decide whether backward passes use
        # quantized matmul on quantized models.
        comfy.model_management.training_fp8_bwd = quantized_backward

        # Process latents based on mode
        if bucket_mode:
            latents = _process_latents_bucket_mode(latents)
        else:
            latents = _process_latents_standard_mode(latents)

        # Process conditioning
        positive = _process_conditioning(positive)

        # Setup model and dtype.  Work on a clone so the caller's model patcher
        # is not mutated by training-specific patches.
        mp = model.clone()
        use_grad_scaler = False
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        if training_dtype != "none":
            dtype = node_helpers.string_to_torch_dtype(training_dtype)
            mp.set_model_compute_dtype(dtype)
        else:
            # Detect model's native dtype for autocast
            model_dtype = mp.model.get_dtype()
            if model_dtype == torch.float16:
                dtype = torch.float16
                # GradScaler only supports float16 gradients, not bfloat16.
                # Only enable it when lora params will also be in float16.
                if lora_dtype != torch.bfloat16:
                    use_grad_scaler = True
                # Warn about fp16 accumulation instability during training
                if PerformanceFeature.Fp16Accumulation in args.fast:
                    logging.warning(
                        "WARNING: FP16 model detected with fp16_accumulation enabled. "
                        "This combination can be numerically unstable during training and may cause NaN values. "
                        "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
                    )
            else:
                # For fp8, bf16, or other dtypes, use bf16 autocast
                dtype = torch.bfloat16

        # Prepare latents and compute counts.  NOTE(review): dtype is assigned
        # on every branch above, so the None fallback here is purely defensive.
        latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
        latents, num_images, multi_res = _prepare_latents_and_count(
            latents, latents_dtype, bucket_mode
        )

        # Validate and expand conditioning
        positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)

        # inference_mode(False) is required so tensors created here (LoRA
        # weights, gradients) are ordinary, autograd-capable tensors.
        with torch.inference_mode(False):
            # Setup models for training: base weights stay frozen, only the
            # adapter parameters receive gradients.
            mp.model.requires_grad_(False)

            # Load existing LoRA weights if provided
            existing_weights, existing_steps = _load_existing_lora(existing_lora)

            # Setup LoRA adapters
            bypass_manager = None
            if bypass_mode:
                logging.debug("Using bypass mode for training")
                lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
                    mp, existing_weights, algorithm, lora_dtype, rank
                )
            else:
                lora_sd, all_weight_adapters = _setup_lora_adapters(
                    mp, existing_weights, algorithm, lora_dtype, rank
                )

            # Create optimizer and loss function
            optimizer = _create_optimizer(
                optimizer_name, lora_sd.values(), learning_rate
            )
            criterion = _create_loss_function(loss_function_name)

            # Setup gradient checkpointing
            if gradient_checkpointing:
                modules_to_patch = find_modules_at_depth(
                    mp.model.diffusion_model, depth=checkpoint_depth
                )
                logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
                for m in modules_to_patch:
                    patch(m, offloading=offloading)

            torch.cuda.empty_cache()
            # With force_full_load=False we should be able to have offloading
            # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
            comfy.model_management.load_models_gpu(
                [mp], memory_required=1e20, force_full_load=not offloading
            )
            torch.cuda.empty_cache()

            # Setup loss tracking; the sampler invokes the callback once per step.
            loss_map = {"loss": []}

            def loss_callback(loss):
                loss_map["loss"].append(loss)

            # Create sampler.  Bucket mode feeds pre-bucketed latents directly;
            # otherwise multi-resolution datasets are passed as real_dataset.
            if bucket_mode:
                train_sampler = TrainSampler(
                    criterion,
                    optimizer,
                    loss_callback=loss_callback,
                    batch_size=batch_size,
                    grad_acc=grad_accumulation_steps,
                    total_steps=steps * grad_accumulation_steps,
                    seed=seed,
                    training_dtype=dtype,
                    bucket_latents=latents,
                    use_grad_scaler=use_grad_scaler,
                )
            else:
                train_sampler = TrainSampler(
                    criterion,
                    optimizer,
                    loss_callback=loss_callback,
                    batch_size=batch_size,
                    grad_acc=grad_accumulation_steps,
                    total_steps=steps * grad_accumulation_steps,
                    seed=seed,
                    training_dtype=dtype,
                    real_dataset=latents if multi_res else None,
                    use_grad_scaler=use_grad_scaler,
                )

            # Setup guider
            guider = TrainGuider(mp, offloading=offloading)
            guider.set_conds(positive)

            # Inject bypass hooks if bypass mode is enabled
            bypass_injections = None
            if bypass_manager is not None:
                bypass_injections = bypass_manager.create_injections(mp.model)
                for injection in bypass_injections:
                    injection.inject(mp)
                logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")

            # Run training loop.  The finally block guarantees the global
            # in_training flag is cleared and hooks/patches are removed even
            # if training raises.
            try:
                comfy.model_management.in_training = True
                _run_training_loop(
                    guider,
                    train_sampler,
                    latents,
                    num_images,
                    seed,
                    bucket_mode,
                    multi_res,
                )
            finally:
                comfy.model_management.in_training = False
                # Eject bypass hooks if they were injected
                if bypass_injections is not None:
                    for injection in bypass_injections:
                        injection.eject(mp)
                    logging.debug("[BypassMode] Ejected bypass hooks")
                for m in mp.model.modules():
                    unpatch(m)
                del train_sampler, optimizer

            # Detach trained weights from the autograd graph and cast to the
            # requested LoRA output dtype.
            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype).detach()

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)
                del adapter
            del all_weight_adapters

            # mp in train node is highly specialized for training
            # use it in inference will result in bad behavior so we don't return it
            return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
|
|
|
|
|
class LoraModelLoader(io.ComfyNode):
    """Apply an in-memory LORA_MODEL (e.g. from TrainLoraNode) to a diffusion model."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader",
            display_name="Load LoRA Model",
            category="loaders",
            is_experimental=True,
            inputs=[
                io.Model.Input(
                    "model", tooltip="The diffusion model the LoRA will be applied to."
                ),
                io.Custom("LORA_MODEL").Input(
                    "lora", tooltip="The LoRA model to apply to the diffusion model."
                ),
                io.Float.Input(
                    "strength_model",
                    default=1.0,
                    min=-100.0,
                    max=100.0,
                    tooltip="How strongly to modify the diffusion model. This value can be negative.",
                ),
                io.Boolean.Input(
                    "bypass",
                    default=False,
                    tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
                ),
            ],
            outputs=[
                io.Model.Output(
                    display_name="model", tooltip="The modified diffusion model."
                ),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model, bypass=False):
        """Return the model with the LoRA applied at the given strength.

        A strength of 0 is a no-op, so the original model is returned
        untouched.  The second argument to the loaders (clip) is None and the
        clip strength is 0 because this node only patches the diffusion model.
        """
        if strength_model == 0:
            return io.NodeOutput(model)

        if bypass:
            # Bypass mode: LoRA applied via hooks instead of weight patching.
            model_lora, _ = comfy.sd.load_bypass_lora_for_models(
                model, None, lora, strength_model, 0
            )
        else:
            model_lora, _ = comfy.sd.load_lora_for_models(
                model, None, lora, strength_model, 0
            )
        return io.NodeOutput(model_lora)
|
|
|
|
|
|
class SaveLoRA(io.ComfyNode):
    """Save a LORA_MODEL state dict to the output directory as a .safetensors file."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA",
            search_aliases=["export lora"],
            display_name="Save LoRA Weights",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LORA_MODEL").Input(
                    "lora",
                    tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
                ),
                io.String.Input(
                    "prefix",
                    default="loras/ComfyUI_trained_lora",
                    tooltip="The prefix to use for the saved LoRA file.",
                ),
                io.Int.Input(
                    "steps",
                    optional=True,
                    tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        """Write the LoRA state dict to disk.

        Args:
            lora: State dict of LoRA weights (name -> tensor).
            prefix: Output path prefix, resolved via folder_paths.
            steps: Optional step count embedded in the filename.
        """
        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(prefix, output_dir)
        )
        # Use the resolved filename from get_save_image_path; the counter
        # guarantees uniqueness within the output folder.
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()
|
|
|
|
|
|
class LossGraphNode(io.ComfyNode):
    """Render the recorded training loss history as a line-graph preview image."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode",
            search_aliases=["training chart", "training visualization", "plot loss"],
            display_name="Plot Loss Graph",
            category="training",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LOSS_MAP").Input(
                    "loss", tooltip="Loss map from training node."
                ),
                io.String.Input(
                    "filename_prefix",
                    default="loss_graph",
                    tooltip="Prefix for the saved loss graph image.",
                ),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        """Draw the loss curve with PIL and return it as a preview image.

        Args:
            loss: Dict with a "loss" key holding the list of loss values.
            filename_prefix: Prefix input declared by the schema.
            prompt, extra_pnginfo: Hidden inputs supplied by the executor.
        """
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        # Guard against a zero range (constant loss, or only one recorded
        # step), which previously raised ZeroDivisionError.
        loss_range = max_loss - min_loss
        if loss_range == 0:
            scaled_loss = [0.5] * len(loss_values)  # flat line at mid-height
        else:
            scaled_loss = [(l - min_loss) / loss_range for l in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            # Fall back to PIL's built-in bitmap font when Arial is unavailable.
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        # Convert PIL image to tensor for PreviewImage
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]  # [1, H, W, 3]

        # Return preview UI
        return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))
|
|
|
|
|
|
# ========== Extension Setup ==========
|
|
|
|
|
|
class TrainingExtension(ComfyExtension):
    """Extension entry point that registers the training-related nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension contributes to ComfyUI."""
        return [
            TrainLoraNode,
            LoraModelLoader,
            SaveLoRA,
            LossGraphNode,
        ]
|
|
|
|
|
|
async def comfy_entrypoint() -> TrainingExtension:
    """Entry point called by ComfyUI to instantiate this extension."""
    return TrainingExtension()
|