Rewrite training system with new io schema

Kohaku-Blueleaf 2025-11-08 15:06:48 +08:00
parent 28f22a517a
commit b3784a7da1
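In short: node definitions move from the dict-based INPUT_TYPES / RETURN_TYPES / FUNCTION class attributes to io.Schema objects from comfy_api.latest, execute becomes a classmethod returning io.NodeOutput, and the NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dicts are replaced by a ComfyExtension entrypoint. A minimal sketch of the pattern this commit applies, using only the API surface visible in the diff below (ExampleNode itself is hypothetical and not part of the commit):

    from comfy_api.latest import io

    class ExampleNode(io.ComfyNode):
        @classmethod
        def define_schema(cls):
            # replaces the old INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY attributes
            return io.Schema(
                node_id="ExampleNode",
                category="training",
                inputs=[io.Int.Input("steps", default=16, tooltip="Example input.")],
                outputs=[io.Int.Output(tooltip="Example output.")],
            )

        @classmethod
        def execute(cls, steps):
            # replaces the old instance method named by FUNCTION
            return io.NodeOutput(steps)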


@@ -1,15 +1,13 @@
import logging
import os
import numpy as np
import safetensors
import torch
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override
import comfy.samplers
import comfy.sd
@@ -18,9 +16,8 @@ import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters, adapter_maps
from comfy_api.latest import ComfyExtension, io, ui


def make_batch_extra_option_dict(d, indicies, full_size=None):
@@ -56,7 +53,18 @@ def process_cond_list(d, prefix=""):
class TrainSampler(comfy.samplers.Sampler):
    def __init__(
        self,
        loss_fn,
        optimizer,
        loss_callback=None,
        batch_size=1,
        grad_acc=1,
        total_steps=1,
        seed=0,
        training_dtype=torch.bfloat16,
        real_dataset=None,
    ):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
@@ -67,51 +75,97 @@ class TrainSampler(comfy.samplers.Sampler):
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset

    def fwd_bwd(
        self,
        model_wrap,
        batch_sigmas,
        batch_noise,
        batch_latent,
        cond,
        indicies,
        extra_args,
        dataset_size,
        bwd=True,
    ):
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False,
        )

        model_wrap.conds["positive"] = [cond[i] for i in indicies]
        batch_extra_args = make_batch_extra_option_dict(
            extra_args, indicies, full_size=dataset_size
        )

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(
                xt.requires_grad_(True),
                batch_sigmas.requires_grad_(True),
                **batch_extra_args,
            )
            loss = self.loss_fn(x0_pred, x0)
        if bwd:
            bwd_loss = loss / self.grad_acc
            bwd_loss.backward()
        return loss

    def sample(
        self,
        model_wrap,
        sigmas,
        extra_args,
        callback,
        noise,
        latent_image=None,
        denoise_mask=None,
        disable_pbar=False,
    ):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (
            pbar := tqdm.trange(
                self.total_steps,
                desc="Training LoRA",
                smoothing=0.01,
                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
            )
        ):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                self.seed + i * 1000
            )
            indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()

            if self.real_dataset is None:
                batch_latent = torch.stack([latent_image[i] for i in indicies])
                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
                    batch_latent.device
                )
                batch_sigmas = [
                    model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    )
                    for _ in range(min(self.batch_size, dataset_size))
                ]
                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
                loss = self.fwd_bwd(
                    model_wrap,
                    batch_sigmas,
                    batch_noise,
                    batch_latent,
                    cond,
                    indicies,
                    extra_args,
                    dataset_size,
                    bwd=True,
                )
                if self.loss_callback:
                    self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
@@ -119,19 +173,34 @@ class TrainSampler(comfy.samplers.Sampler):
                total_loss = 0
                for index in indicies:
                    single_latent = self.real_dataset[index].to(latent_image)
                    batch_noise = noisegen.generate_noise(
                        {"samples": single_latent}
                    ).to(single_latent.device)
                    batch_sigmas = (
                        model_wrap.inner_model.model_sampling.percent_to_sigma(
                            torch.rand((1,)).item()
                        )
                    )
                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
                    loss = self.fwd_bwd(
                        model_wrap,
                        batch_sigmas,
                        batch_noise,
                        single_latent,
                        cond,
                        [index],
                        extra_args,
                        dataset_size,
                        bwd=False,
                    )
                    total_loss += loss
                total_loss = total_loss / self.grad_acc / len(indicies)
                total_loss.backward()
                if self.loss_callback:
                    self.loss_callback(total_loss.item())
                pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        torch.cuda.empty_cache()
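Note on the gradient-accumulation change above: fwd_bwd now scales each micro-batch loss by 1 / grad_acc before calling backward (or skips backward entirely in the multi-resolution branch, which accumulates the losses and backpropagates once), and the optimizer only steps every grad_acc iterations, so total_steps = steps * grad_accumulation_steps still yields steps optimizer updates. A self-contained sketch of that pattern (illustrative only, not code from this commit):

    import torch

    param = torch.nn.Parameter(torch.zeros(4))
    optimizer = torch.optim.SGD([param], lr=0.1)
    grad_acc, steps = 4, 2
    for i in range(steps * grad_acc):  # mirrors total_steps = steps * grad_acc
        loss = (param - torch.randn(4)).pow(2).mean()
        (loss / grad_acc).backward()   # scale so the accumulated gradient is an average
        if (i + 1) % grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()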
@@ -173,10 +242,14 @@ def draw_loss_graph(loss_map, steps):
    return img


def find_all_highest_child_module_with_forward(
    model: torch.nn.Module, result=None, name=None
):
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(
        model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    ):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
@@ -190,12 +263,13 @@ def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

    m.org_forward = org_forward
    m.forward = checkpointing_fwd
@@ -206,130 +280,120 @@ def unpatch(m):
        del m.org_forward


class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode",
            display_name="Train LoRA",
            category="training",
            is_experimental=True,
            is_input_list=True,  # All inputs become lists
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input(
                    "latents",
                    tooltip="The Latents to use for training, serve as dataset/input of the model.",
                ),
                io.Conditioning.Input(
                    "positive", tooltip="The positive conditioning to use for training."
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=10000,
                    tooltip="The batch size to use for training.",
                ),
                io.Int.Input(
                    "grad_accumulation_steps",
                    default=1,
                    min=1,
                    max=1024,
                    tooltip="The number of gradient accumulation steps to use for training.",
                ),
                io.Int.Input(
                    "steps",
                    default=16,
                    min=1,
                    max=100000,
                    tooltip="The number of steps to train the LoRA for.",
                ),
                io.Float.Input(
                    "learning_rate",
                    default=0.0005,
                    min=0.0000001,
                    max=1.0,
                    step=0.0000001,
                    tooltip="The learning rate to use for training.",
                ),
                io.Int.Input(
                    "rank",
                    default=8,
                    min=1,
                    max=128,
                    tooltip="The rank of the LoRA layers.",
                ),
                io.Combo.Input(
                    "optimizer",
                    options=["AdamW", "Adam", "SGD", "RMSprop"],
                    default="AdamW",
                    tooltip="The optimizer to use for training.",
                ),
                io.Combo.Input(
                    "loss_function",
                    options=["MSE", "L1", "Huber", "SmoothL1"],
                    default="MSE",
                    tooltip="The loss function to use for training.",
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
                ),
                io.Combo.Input(
                    "training_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for training.",
                ),
                io.Combo.Input(
                    "lora_dtype",
                    options=["bf16", "fp32"],
                    default="bf16",
                    tooltip="The dtype to use for lora.",
                ),
                io.Combo.Input(
                    "algorithm",
                    options=list(adapter_maps.keys()),
                    default=list(adapter_maps.keys())[0],
                    tooltip="The algorithm to use for training.",
                ),
                io.Boolean.Input(
                    "gradient_checkpointing",
                    default=True,
                    tooltip="Use gradient checkpointing for training.",
                ),
                io.Combo.Input(
                    "existing_lora",
                    options=folder_paths.get_filename_list("loras") + ["[None]"],
                    default="[None]",
                    tooltip="The existing LoRA to append to. Set to None for new LoRA.",
                ),
            ],
            outputs=[
                io.Model.Output(tooltip="Model with LoRA applied"),
                io.Custom("LORA_MODEL").Output(tooltip="LoRA weights"),
                io.Custom("LOSS_MAP").Output(tooltip="Loss history"),
                io.Int.Output(tooltip="Total training steps"),
            ],
        )

    @classmethod
    def execute(
        cls,
        model,
        latents,
        positive,
@@ -347,13 +411,56 @@ class TrainLoraNode:
        gradient_checkpointing,
        existing_lora,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
        batch_size = batch_size[0]
        steps = steps[0]
        grad_accumulation_steps = grad_accumulation_steps[0]
        learning_rate = learning_rate[0]
        rank = rank[0]
        optimizer = optimizer[0]
        loss_function = loss_function[0]
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
        existing_lora = existing_lora[0]

        # Handle latents - either single dict or list of dicts
        if len(latents) == 1:
            latents = latents[0]["samples"]  # Single latent dict
        else:
            latent_list = []
            for latent in latents:
                latent = latent["samples"]
                bs = latent.shape[0]
                if bs != 1:
                    for sub_latent in latent:
                        latent_list.append(sub_latent[None])
                else:
                    latent_list.append(latent)
            latents = latent_list

        # Handle conditioning - either single list or list of lists
        if len(positive) == 1:
            positive = positive[0]  # Single conditioning list
        else:
            # Multiple conditioning lists - flatten
            flat_positive = []
            for cond in positive:
                if isinstance(cond, list):
                    flat_positive.extend(cond)
                else:
                    flat_positive.append(cond)
            positive = flat_positive

        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        # latents here can be list of different size latent or one large batch
        if isinstance(latents, list):
            all_shapes = set()
            latents = [t.to(dtype) for t in latents]
@@ -366,8 +473,8 @@ class TrainLoraNode:
                multi_res = False
                latents = torch.cat(latents, dim=0)
                num_images = len(latents)
        elif isinstance(latents, torch.Tensor):
            latents = latents.to(dtype)
            num_images = latents.shape[0]
        else:
            logging.error(f"Invalid latents type: {type(latents)}")
@@ -403,9 +510,7 @@ class TrainLoraNode:
                        shape = m.weight.shape
                        if len(shape) >= 2:
                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                            dora_scale = existing_weights.get(f"{key}.dora_scale", None)
                            for adapter_cls in adapters:
                                existing_adapter = adapter_cls.load(
                                    n, existing_weights, alpha, dora_scale
@@ -417,7 +522,9 @@
                                adapter_cls = adapter_maps[algorithm]

                            if existing_adapter is not None:
                                train_adapter = existing_adapter.to_train().to(
                                    lora_dtype
                                )
                            else:
                                # Use LoRA with alpha=1.0 by default
                                train_adapter = adapter_cls.create_train(
@@ -441,7 +548,9 @@
                if hasattr(m, "bias") and m.bias is not None:
                    key = "{}.bias".format(n)
                    bias = torch.nn.Parameter(
                        torch.zeros(
                            m.bias.shape, dtype=lora_dtype, requires_grad=True
                        )
                    )
                    bias_module = BiasDiff(bias)
                    lora_sd["{}.diff_b".format(n)] = bias
@@ -469,25 +578,31 @@
        # setup models
        if gradient_checkpointing:
            for m in find_all_highest_child_module_with_forward(
                mp.model.diffusion_model
            ):
                patch(m)
        mp.model.requires_grad_(False)
        comfy.model_management.load_models_gpu(
            [mp], memory_required=1e20, force_full_load=True
        )

        # Setup sampler and guider like in test script
        loss_map = {"loss": []}

        def loss_callback(loss):
            loss_map["loss"].append(loss)

        train_sampler = TrainSampler(
            criterion,
            optimizer,
            loss_callback=loss_callback,
            batch_size=batch_size,
            grad_acc=grad_accumulation_steps,
            total_steps=steps * grad_accumulation_steps,
            seed=seed,
            training_dtype=dtype,
            real_dataset=latents if multi_res else None,
        )
        guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
        guider.set_conds(positive)  # Set conditioning from input
@@ -505,7 +620,7 @@
                latents,
                train_sampler,
                sigmas,
                seed=noise.seed,
            )
        finally:
            for m in mp.model.modules():
@@ -518,111 +633,116 @@
        for param in lora_sd:
            lora_sd[param] = lora_sd[param].to(lora_dtype)

        return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader",
            display_name="Load LoRA Model",
            category="loaders",
            is_experimental=True,
            inputs=[
                io.Model.Input(
                    "model", tooltip="The diffusion model the LoRA will be applied to."
                ),
                io.Custom("LORA_MODEL").Input(
                    "lora", tooltip="The LoRA model to apply to the diffusion model."
                ),
                io.Float.Input(
                    "strength_model",
                    default=1.0,
                    min=-100.0,
                    max=100.0,
                    tooltip="How strongly to modify the diffusion model. This value can be negative.",
                ),
            ],
            outputs=[
                io.Model.Output(tooltip="The modified diffusion model."),
            ],
        )

    @classmethod
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(
            model, None, lora, strength_model, 0
        )
        return io.NodeOutput(model_lora)


class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA",
            display_name="Save LoRA Weights",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LORA_MODEL").Input(
                    "lora",
                    tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
                ),
                io.String.Input(
                    "prefix",
                    default="loras/ComfyUI_trained_lora",
                    tooltip="The prefix to use for the saved LoRA file.",
                ),
                io.Int.Input(
                    "steps",
                    optional=True,
                    tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, lora, prefix, steps=None):
        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(prefix, output_dir)
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return io.NodeOutput()


class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode",
            display_name="Plot Loss Graph",
            category="training",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LOSS_MAP").Input(
                    "loss", tooltip="Loss map from training node."
                ),
                io.String.Input(
                    "filename_prefix",
                    default="loss_graph",
                    tooltip="Prefix for the saved loss graph image.",
                ),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40
@@ -665,43 +785,27 @@ class LossGraphNode:
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        # Convert PIL image to tensor for PreviewImage
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]  # [1, H, W, 3]

        # Return preview UI
        return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))


# ========== Extension Setup ==========


class TrainingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TrainLoraNode,
            LoraModelLoader,
            SaveLoRA,
            LossGraphNode,
        ]


async def comfy_entrypoint() -> TrainingExtension:
    return TrainingExtension()