import datetime
import json
import logging
import os

import numpy as np
import safetensors.torch
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo

import comfy.samplers
import comfy.sd
import comfy.utils
import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters, adapter_maps


def make_batch_extra_option_dict(d, indicies, full_size=None):
    # Recursively slice batched entries (tensors/lists of dataset size) in the
    # extra_args dict down to the sampled batch indices.
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indicies]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indicies]
        new_dict[k] = newv
    return new_dict


def process_cond_list(d, prefix=""):
    # Recursively clone all conditioning tensors in place so they are regular
    # tensors that can participate in autograd (not inference-mode tensors).
    if hasattr(d, "__iter__") and not hasattr(d, "items"):
        for index, item in enumerate(d):
            process_cond_list(item, f"{prefix}.{index}")
        return d
    elif hasattr(d, "items"):
        for k, v in list(d.items()):
            if isinstance(v, dict):
                process_cond_list(v, f"{prefix}.{k}")
            elif isinstance(v, torch.Tensor):
                d[k] = v.clone()
            elif isinstance(v, (list, tuple)):
                for index, item in enumerate(v):
                    process_cond_list(item, f"{prefix}.{k}.{index}")
    return d


class TrainSampler(comfy.samplers.Sampler):
    # A comfy.samplers.Sampler that runs an optimization loop instead of denoising.
    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
        self.batch_size = batch_size
        self.total_steps = total_steps
        self.grad_acc = grad_acc
        self.seed = seed
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset

    def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False
        )

        model_wrap.conds["positive"] = [cond[i] for i in indicies]
        batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
            loss = self.loss_fn(x0_pred, x0)
        loss.backward()
        return loss
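    # What one fwd_bwd() call computes, as a rough sketch (the exact noising is
    # delegated to model_sampling.noise_scaling and differs per model family):
    #
    #   xt   = noise_scaling(sigma, noise, latent)  # noised input at level sigma
    #   x0   = noise_scaling(0, 0, latent)          # clean target
    #   loss = loss_fn(model(xt, sigma), x0)        # x0-prediction objective
    #   loss.backward()                             # grads accumulate until the
    #                                               # next optimizer.step()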
    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (pbar := tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()

            if self.real_dataset is None:
                # Uniform-resolution dataset: train on a stacked mini-batch.
                batch_latent = torch.stack([latent_image[i] for i in indicies])
                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
                batch_sigmas = [
                    model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    ) for _ in range(min(self.batch_size, dataset_size))
                ]
                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

                loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
                if self.loss_callback:
                    self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
            else:
                # Mixed-resolution dataset: the latents have different shapes and
                # cannot be stacked, so run each sample separately and average.
                total_loss = 0
                for index in indicies:
                    single_latent = self.real_dataset[index].to(latent_image)
                    batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
                    batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    )
                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
                    loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
                    total_loss += loss.item()
                total_loss /= len(indicies)
                if self.loss_callback:
                    self.loss_callback(total_loss)
                pbar.set_postfix({"loss": f"{total_loss:.4f}"})

            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                torch.cuda.empty_cache()
        return torch.zeros_like(latent_image)


class BiasDiff(torch.nn.Module):
    # Weight wrapper that adds a trainable additive delta (used for biases and
    # for weights that LoRA-style adapters do not cover).
    def __init__(self, bias):
        super().__init__()
        self.bias = bias

    def __call__(self, b):
        org_dtype = b.dtype
        return (b.to(self.bias) + self.bias).to(org_dtype)

    def passive_memory_usage(self):
        return self.bias.nelement() * self.bias.element_size()

    def move_to(self, device):
        self.to(device=device)
        return self.passive_memory_usage()


def draw_loss_graph(loss_map, steps):
    # Standalone helper that renders a bare loss curve (LossGraphNode below
    # draws its own labeled version).
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(img)

    min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
    loss_range = (max_loss - min_loss) or 1.0  # avoid division by zero on a flat curve
    scaled_loss = [(l - min_loss) / loss_range for l in loss_map.values()]

    prev_point = (0, height - int(scaled_loss[0] * height))
    for i, l in enumerate(scaled_loss[1:], start=1):
        x = int(i / (steps - 1) * width)
        y = height - int(l * height)
        draw.line([prev_point, (x, y)], fill="blue", width=2)
        prev_point = (x, y)

    return img


def find_all_highest_child_module_with_forward(model: torch.nn.Module, result=None, name=None):
    # Collect the highest-level child modules that define forward(), skipping
    # bare containers; these are the units that get gradient checkpointed.
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
    name = name or "root"
    for next_name, child in model.named_children():
        find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
    return result


def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )

    m.org_forward = org_forward
    m.forward = checkpointing_fwd


def unpatch(m):
    if hasattr(m, "org_forward"):
        m.forward = m.org_forward
        del m.org_forward
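# patch()/unpatch() implement gradient checkpointing by swapping a module's
# forward for one that recomputes activations during backward. The original
# (args, kwargs) are packed into two positional arguments and unpacked inside
# fwd() so keyword arguments survive the checkpoint boundary. A minimal
# illustration of the same pattern on a plain module (names here are
# illustrative, not part of this file's API):
#
#   layer = torch.nn.Linear(8, 8)
#   x = torch.randn(2, 8, requires_grad=True)
#   out = torch.utils.checkpoint.checkpoint(
#       lambda a, kw: layer(*a, **kw), (x,), {}, use_reentrant=False
#   )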
class TrainLoraNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
                "latents": (
                    "LATENT",
                    {
                        "tooltip": "The latents to use for training; they serve as the dataset/input of the model."
                    },
                ),
                "positive": (
                    IO.CONDITIONING,
                    {"tooltip": "The positive conditioning to use for training."},
                ),
                "batch_size": (
                    IO.INT,
                    {
                        "default": 1,
                        "min": 1,
                        "max": 10000,
                        "step": 1,
                        "tooltip": "The batch size to use for training.",
                    },
                ),
                "grad_accumulation_steps": (
                    IO.INT,
                    {
                        "default": 1,
                        "min": 1,
                        "max": 1024,
                        "step": 1,
                        "tooltip": "The number of gradient accumulation steps to use for training.",
                    },
                ),
                "steps": (
                    IO.INT,
                    {
                        "default": 16,
                        "min": 1,
                        "max": 100000,
                        "tooltip": "The number of steps to train the LoRA for.",
                    },
                ),
                "learning_rate": (
                    IO.FLOAT,
                    {
                        "default": 0.0005,
                        "min": 0.0000001,
                        "max": 1.0,
                        "step": 0.000001,
                        "tooltip": "The learning rate to use for training.",
                    },
                ),
                "rank": (
                    IO.INT,
                    {
                        "default": 8,
                        "min": 1,
                        "max": 128,
                        "tooltip": "The rank of the LoRA layers.",
                    },
                ),
                "optimizer": (
                    ["AdamW", "Adam", "SGD", "RMSprop"],
                    {
                        "default": "AdamW",
                        "tooltip": "The optimizer to use for training.",
                    },
                ),
                "loss_function": (
                    ["MSE", "L1", "Huber", "SmoothL1"],
                    {
                        "default": "MSE",
                        "tooltip": "The loss function to use for training.",
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "tooltip": "The seed to use for training (used in the generator for LoRA weight initialization and noise sampling).",
                    },
                ),
                "training_dtype": (
                    ["bf16", "fp32"],
                    {"default": "bf16", "tooltip": "The dtype to use for training."},
                ),
                "lora_dtype": (
                    ["bf16", "fp32"],
                    {"default": "bf16", "tooltip": "The dtype to use for the LoRA."},
                ),
                "algorithm": (
                    list(adapter_maps.keys()),
                    {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
                ),
                "gradient_checkpointing": (
                    IO.BOOLEAN,
                    {
                        "default": True,
                        "tooltip": "Use gradient checkpointing for training.",
                    },
                ),
                "existing_lora": (
                    folder_paths.get_filename_list("loras") + ["[None]"],
                    {
                        "default": "[None]",
                        "tooltip": "The existing LoRA to append to. Set to [None] to train a new LoRA.",
                    },
                ),
            },
        }

    RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
    RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
    FUNCTION = "train"

    CATEGORY = "training"
    EXPERIMENTAL = True

    def train(
        self,
        model,
        latents,
        positive,
        batch_size,
        steps,
        grad_accumulation_steps,
        learning_rate,
        rank,
        optimizer,
        loss_function,
        seed,
        training_dtype,
        lora_dtype,
        algorithm,
        gradient_checkpointing,
        existing_lora,
    ):
        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        # latents here can be a list of different-size latents or one large batch
        latents = latents["samples"]
        if isinstance(latents, list):
            all_shapes = set()
            latents = [t.to(dtype) for t in latents]
            for latent in latents:
                all_shapes.add(latent.shape)
            logging.info(f"Latent shapes: {all_shapes}")
            if len(all_shapes) > 1:
                multi_res = True
            else:
                multi_res = False
                latents = torch.cat(latents, dim=0)
            num_images = len(latents)
        elif isinstance(latents, torch.Tensor):
            multi_res = False
            latents = latents.to(dtype)
            num_images = latents.shape[0]
        else:
            logging.error(f"Invalid latents type: {type(latents)}")
            raise ValueError(f"Invalid latents type: {type(latents)}")

        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
        if len(positive) == 1 and num_images > 1:
            positive = positive * num_images
        elif len(positive) != num_images:
            raise ValueError(
                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
            )
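        # ComfyUI executes nodes under torch.inference_mode(), which marks every
        # tensor created there as inference-only and unusable in autograd.
        # Training therefore opts back out below via torch.inference_mode(False);
        # process_cond_list() likewise clones the conditioning tensors so they
        # can participate in the backward pass.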
        with torch.inference_mode(False):
            lora_sd = {}
            generator = torch.Generator()
            generator.manual_seed(seed)

            # Load existing LoRA weights if provided
            existing_weights = {}
            existing_steps = 0
            if existing_lora != "[None]":
                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
                if lora_path:
                    existing_weights = comfy.utils.load_torch_file(lora_path)

            all_weight_adapters = []
            for n, m in mp.model.named_modules():
                if hasattr(m, "weight_function"):
                    if m.weight is not None:
                        key = "{}.weight".format(n)
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(
                                f"{key}.dora_scale", None
                            )
                            # Try to resume from an existing adapter for this key.
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
                                )
                                if existing_adapter is not None:
                                    break
                            else:
                                existing_adapter = None
                                adapter_cls = adapter_maps[algorithm]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(lora_dtype)
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
                                    m.weight, rank=rank, alpha=1.0
                                ).to(lora_dtype)

                            for name, parameter in train_adapter.named_parameters():
                                lora_sd[f"{n}.{name}"] = parameter

                            mp.add_weight_wrapper(key, train_adapter)
                            all_weight_adapters.append(train_adapter)
                        else:
                            # 1D weights get a plain additive delta instead of a LoRA.
                            diff = torch.nn.Parameter(
                                torch.zeros(
                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
                                )
                            )
                            diff_module = BiasDiff(diff)
                            mp.add_weight_wrapper(key, diff_module)
                            all_weight_adapters.append(diff_module)
                            lora_sd["{}.diff".format(n)] = diff

                    if hasattr(m, "bias") and m.bias is not None:
                        key = "{}.bias".format(n)
                        bias = torch.nn.Parameter(
                            torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        )
                        bias_module = BiasDiff(bias)
                        lora_sd["{}.diff_b".format(n)] = bias
                        mp.add_weight_wrapper(key, bias_module)
                        all_weight_adapters.append(bias_module)

            if optimizer == "Adam":
                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
            elif optimizer == "AdamW":
                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
            elif optimizer == "SGD":
                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
            elif optimizer == "RMSprop":
                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)

            # Set up the loss function based on the selection
            if loss_function == "MSE":
                criterion = torch.nn.MSELoss()
            elif loss_function == "L1":
                criterion = torch.nn.L1Loss()
            elif loss_function == "Huber":
                criterion = torch.nn.HuberLoss()
            elif loss_function == "SmoothL1":
                criterion = torch.nn.SmoothL1Loss()

            # Set up the model for training
            if gradient_checkpointing:
                for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                    patch(m)
            mp.model.requires_grad_(False)
            comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)

            # Set up the training sampler and guider
            loss_map = {"loss": []}

            def loss_callback(loss):
                loss_map["loss"].append(loss)

            train_sampler = TrainSampler(
                criterion,
                optimizer,
                loss_callback=loss_callback,
                batch_size=batch_size,
                grad_acc=grad_accumulation_steps,
                total_steps=steps * grad_accumulation_steps,
                seed=seed,
                training_dtype=dtype,
                real_dataset=latents if multi_res else None,
            )
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
            guider.set_conds(positive)  # Set conditioning from input

            # Training loop
            try:
                # Generate dummy sigmas and noise
                sigmas = torch.tensor(range(num_images))
                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
                if multi_res:
                    # use the first latent as a dummy latent if multi_res
                    latents = latents[0].repeat(num_images, 1, 1, 1)
                guider.sample(
                    noise.generate_noise({"samples": latents}), latents, train_sampler, sigmas, seed=noise.seed
                )
            finally:
                for m in mp.model.modules():
                    unpatch(m)
                del train_sampler, optimizer

            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)

            for param in lora_sd:
                lora_sd[param] = lora_sd[param].to(lora_dtype)

        return (mp, lora_sd, loss_map, steps + existing_steps)
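# The returned lora_sd maps flattened module paths to trainable tensors, the
# layout comfy.sd.load_lora_for_models() consumes. A rough sketch of what a
# single patched layer contributes, assuming the default LoRA adapter of rank
# r (the exact key names depend on the chosen adapter and are illustrative):
#
#   {
#       "<module_path>.lora_up.weight":   Tensor[out_features, r],
#       "<module_path>.lora_down.weight": Tensor[r, in_features],
#       "<module_path>.diff_b":           Tensor[out_features],  # bias delta
#   }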
class LoraModelLoader:
    def __init__(self):
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
    FUNCTION = "load_lora_model"

    CATEGORY = "loaders"
    DESCRIPTION = "Load trained LoRA weights from the Train LoRA node."
    EXPERIMENTAL = True

    def load_lora_model(self, model, lora, strength_model):
        if strength_model == 0:
            return (model,)

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return (model_lora,)


class SaveLoRA:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (
                    IO.LORA_MODEL,
                    {
                        "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
                    },
                ),
                "prefix": (
                    "STRING",
                    {
                        "default": "loras/ComfyUI_trained_lora",
                        "tooltip": "The prefix to use for the saved LoRA file.",
                    },
                ),
            },
            "optional": {
                "steps": (
                    IO.INT,
                    {
                        "forceInput": True,
                        "tooltip": "Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
                    },
                ),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    OUTPUT_NODE = True

    def save(self, lora, prefix, steps=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        safetensors.torch.save_file(lora, output_checkpoint)
        return {}
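# The LOSS_MAP passed between nodes is a plain dict of float lists, e.g.
# {"loss": [0.512, 0.498, ...]} with one entry per optimizer step (see
# loss_callback inside TrainLoraNode.train). LossGraphNode below reads only
# the "loss" key.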
class LossGraphNode:
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "loss": (IO.LOSS_MAP, {"default": {}}),
                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "plot_loss"
    OUTPUT_NODE = True
    CATEGORY = "training"
    EXPERIMENTAL = True
    DESCRIPTION = "Plots the loss graph and saves it to the output directory."

    def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40

        img = Image.new(
            "RGB", (width + margin, height + margin), "white"
        )  # Extend canvas
        draw = ImageDraw.Draw(img)

        min_loss, max_loss = min(loss_values), max(loss_values)
        loss_range = (max_loss - min_loss) or 1.0  # avoid division by zero on a flat curve
        scaled_loss = [(l - min_loss) / loss_range for l in loss_values]

        steps = len(loss_values)

        prev_point = (margin, height - int(scaled_loss[0] * height))
        for i, l in enumerate(scaled_loss[1:], start=1):
            x = margin + int(i / steps * width)  # Scale X properly
            y = height - int(l * height)
            draw.line([prev_point, (x, y)], fill="blue", width=2)
            prev_point = (x, y)

        draw.line([(margin, 0), (margin, height)], fill="black", width=2)  # Y-axis
        draw.line(
            [(margin, height), (width + margin, height)], fill="black", width=2
        )  # X-axis

        try:
            font = ImageFont.truetype("arial.ttf", 12)
        except IOError:
            font = ImageFont.load_default()

        # Add axis labels
        draw.text((5, height // 2), "Loss", font=font, fill="black")
        draw.text((width // 2, height + 10), "Steps", font=font, fill="black")

        # Add min/max loss values
        draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
        draw.text(
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        metadata = None
        if not args.disable_metadata:
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))

        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        img.save(
            os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
            pnginfo=metadata,
        )
        return {
            "ui": {
                "images": [
                    {
                        "filename": f"{filename_prefix}_{date}.png",
                        "subfolder": "",
                        "type": "temp",
                    }
                ]
            }
        }


NODE_CLASS_MAPPINGS = {
    "TrainLoraNode": TrainLoraNode,
    "SaveLoRANode": SaveLoRA,
    "LoraModelLoader": LoraModelLoader,
    "LossGraphNode": LossGraphNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "TrainLoraNode": "Train LoRA",
    "SaveLoRANode": "Save LoRA Weights",
    "LoraModelLoader": "Load LoRA Model",
    "LossGraphNode": "Plot Loss Graph",
}
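# Typical wiring, as a sketch (not enforced anywhere in this file): feed a
# MODEL plus LATENT/CONDITIONING dataset into TrainLoraNode; route its "lora"
# and "steps" outputs into SaveLoRA to write a .safetensors file, its "loss"
# output into LossGraphNode to render the training curve, and apply the
# in-memory "lora" output to another model via LoraModelLoader.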