Rewrite training system with new io schema
parent 28f22a517a
commit b3784a7da1
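
The diff below migrates the training nodes from the legacy INPUT_TYPES/RETURN_TYPES class attributes to the comfy_api.latest io schema. For orientation, a minimal sketch of the new pattern, distilled from the diff itself and not part of the commit (the ExampleNode and ExampleExtension names are hypothetical), looks like this:

from comfy_api.latest import ComfyExtension, io


class ExampleNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        # Inputs and outputs are declared as typed objects instead of dict entries.
        return io.Schema(
            node_id="ExampleNode",
            display_name="Example Node",
            category="training",
            inputs=[
                io.Model.Input("model", tooltip="Model to operate on."),
                io.Int.Input("steps", default=16, min=1, max=100000),
            ],
            outputs=[io.Model.Output(tooltip="Unchanged model.")],
        )

    @classmethod
    def execute(cls, model, steps):
        # Results are wrapped in io.NodeOutput instead of a plain tuple.
        return io.NodeOutput(model)


class ExampleExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [ExampleNode]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()
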
@@ -1,15 +1,13 @@
import datetime
import json
import logging
import os

import numpy as np
import safetensors
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL.PngImagePlugin import PngInfo
import torch.utils.checkpoint
import tqdm
from PIL import Image, ImageDraw, ImageFont
from typing_extensions import override

import comfy.samplers
import comfy.sd
@@ -18,9 +16,8 @@ import comfy.model_management
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types.node_typing import IO
from comfy.weight_adapter import adapters, adapter_maps
from comfy_api.latest import ComfyExtension, io, ui


def make_batch_extra_option_dict(d, indicies, full_size=None):
@@ -56,7 +53,18 @@ def process_cond_list(d, prefix=""):


class TrainSampler(comfy.samplers.Sampler):
    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16, real_dataset=None):
    def __init__(
        self,
        loss_fn,
        optimizer,
        loss_callback=None,
        batch_size=1,
        grad_acc=1,
        total_steps=1,
        seed=0,
        training_dtype=torch.bfloat16,
        real_dataset=None,
    ):
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.loss_callback = loss_callback
@@ -67,51 +75,97 @@ class TrainSampler(comfy.samplers.Sampler):
        self.training_dtype = training_dtype
        self.real_dataset: list[torch.Tensor] | None = real_dataset

    def fwd_bwd(self, model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size):
    def fwd_bwd(
        self,
        model_wrap,
        batch_sigmas,
        batch_noise,
        batch_latent,
        cond,
        indicies,
        extra_args,
        dataset_size,
        bwd=True,
    ):
        xt = model_wrap.inner_model.model_sampling.noise_scaling(
            batch_sigmas,
            batch_noise,
            batch_latent,
            False
            batch_sigmas, batch_noise, batch_latent, False
        )
        x0 = model_wrap.inner_model.model_sampling.noise_scaling(
            torch.zeros_like(batch_sigmas),
            torch.zeros_like(batch_noise),
            batch_latent,
            False
            False,
        )

        model_wrap.conds["positive"] = [
            cond[i] for i in indicies
        ]
        batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
        model_wrap.conds["positive"] = [cond[i] for i in indicies]
        batch_extra_args = make_batch_extra_option_dict(
            extra_args, indicies, full_size=dataset_size
        )

        with torch.autocast(xt.device.type, dtype=self.training_dtype):
            x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
            x0_pred = model_wrap(
                xt.requires_grad_(True),
                batch_sigmas.requires_grad_(True),
                **batch_extra_args,
            )
            loss = self.loss_fn(x0_pred, x0)
        loss.backward()
        if bwd:
            bwd_loss = loss / self.grad_acc
            bwd_loss.backward()
        return loss

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
    def sample(
        self,
        model_wrap,
        sigmas,
        extra_args,
        callback,
        noise,
        latent_image=None,
        denoise_mask=None,
        disable_pbar=False,
    ):
        model_wrap.conds = process_cond_list(model_wrap.conds)
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
        for i in (
            pbar := tqdm.trange(
                self.total_steps,
                desc="Training LoRA",
                smoothing=0.01,
                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
            )
        ):
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
                self.seed + i * 1000
            )
            indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()

            if self.real_dataset is None:
                batch_latent = torch.stack([latent_image[i] for i in indicies])
                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
                batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
                    batch_latent.device
                )
                batch_sigmas = [
                    model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    ) for _ in range(min(self.batch_size, dataset_size))
                    )
                    for _ in range(min(self.batch_size, dataset_size))
                ]
                batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

                loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, batch_latent, cond, indicies, extra_args, dataset_size)
                loss = self.fwd_bwd(
                    model_wrap,
                    batch_sigmas,
                    batch_noise,
                    batch_latent,
                    cond,
                    indicies,
                    extra_args,
                    dataset_size,
                    bwd=True,
                )
                if self.loss_callback:
                    self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})
@@ -119,19 +173,34 @@ class TrainSampler(comfy.samplers.Sampler):
                total_loss = 0
                for index in indicies:
                    single_latent = self.real_dataset[index].to(latent_image)
                    batch_noise = noisegen.generate_noise({"samples": single_latent}).to(single_latent.device)
                    batch_sigmas = model_wrap.inner_model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
                    batch_noise = noisegen.generate_noise(
                        {"samples": single_latent}
                    ).to(single_latent.device)
                    batch_sigmas = (
                        model_wrap.inner_model.model_sampling.percent_to_sigma(
                            torch.rand((1,)).item()
                        )
                    )
                    batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device)
                    loss = self.fwd_bwd(model_wrap, batch_sigmas, batch_noise, single_latent, cond, [index], extra_args, dataset_size)
                    total_loss += loss.item()
                total_loss /= len(indicies)
                    loss = self.fwd_bwd(
                        model_wrap,
                        batch_sigmas,
                        batch_noise,
                        single_latent,
                        cond,
                        [index],
                        extra_args,
                        dataset_size,
                        bwd=False,
                    )
                    total_loss += loss
                total_loss = total_loss / self.grad_acc / len(indicies)
                total_loss.backward()
                if self.loss_callback:
                    self.loss_callback(loss.item())
                pbar.set_postfix({"loss": f"{total_loss/(index+1):.4f}"})
                    self.loss_callback(total_loss.item())
                pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})

            if (i+1) % self.grad_acc == 0:
            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                torch.cuda.empty_cache()
@@ -173,10 +242,14 @@ def draw_loss_graph(loss_map, steps):
    return img


def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
def find_all_highest_child_module_with_forward(
    model: torch.nn.Module, result=None, name=None
):
    if result is None:
        result = []
    elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
    elif hasattr(model, "forward") and not isinstance(
        model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
    ):
        result.append(model)
        logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
        return result
@@ -190,12 +263,13 @@ def patch(m):
    if not hasattr(m, "forward"):
        return
    org_forward = m.forward

    def fwd(args, kwargs):
        return org_forward(*args, **kwargs)

    def checkpointing_fwd(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            fwd, args, kwargs, use_reentrant=False
        )
        return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)

    m.org_forward = org_forward
    m.forward = checkpointing_fwd

@@ -206,130 +280,120 @@ def unpatch(m):
    del m.org_forward


class TrainLoraNode:
class TrainLoraNode(io.ComfyNode):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}),
                "latents": (
                    "LATENT",
                    {
                        "tooltip": "The Latents to use for training, serve as dataset/input of the model."
                    },
    def define_schema(cls):
        return io.Schema(
            node_id="TrainLoraNode",
            display_name="Train LoRA",
            category="training",
            is_experimental=True,
            is_input_list=True,  # All inputs become lists
            inputs=[
                io.Model.Input("model", tooltip="The model to train the LoRA on."),
                io.Latent.Input(
                    "latents",
                    tooltip="The Latents to use for training, serve as dataset/input of the model.",
                ),
"positive": (
|
||||
IO.CONDITIONING,
|
||||
{"tooltip": "The positive conditioning to use for training."},
|
||||
io.Conditioning.Input(
|
||||
"positive", tooltip="The positive conditioning to use for training."
|
||||
),
|
||||
"batch_size": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 10000,
|
||||
"step": 1,
|
||||
"tooltip": "The batch size to use for training.",
|
||||
},
|
||||
io.Int.Input(
|
||||
"batch_size",
|
||||
default=1,
|
||||
min=1,
|
||||
max=10000,
|
||||
tooltip="The batch size to use for training.",
|
||||
),
|
||||
"grad_accumulation_steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 1024,
|
||||
"step": 1,
|
||||
"tooltip": "The number of gradient accumulation steps to use for training.",
|
||||
}
|
||||
io.Int.Input(
|
||||
"grad_accumulation_steps",
|
||||
default=1,
|
||||
min=1,
|
||||
max=1024,
|
||||
tooltip="The number of gradient accumulation steps to use for training.",
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 16,
|
||||
"min": 1,
|
||||
"max": 100000,
|
||||
"tooltip": "The number of steps to train the LoRA for.",
|
||||
},
|
||||
io.Int.Input(
|
||||
"steps",
|
||||
default=16,
|
||||
min=1,
|
||||
max=100000,
|
||||
tooltip="The number of steps to train the LoRA for.",
|
||||
),
|
||||
"learning_rate": (
|
||||
IO.FLOAT,
|
||||
{
|
||||
"default": 0.0005,
|
||||
"min": 0.0000001,
|
||||
"max": 1.0,
|
||||
"step": 0.000001,
|
||||
"tooltip": "The learning rate to use for training.",
|
||||
},
|
||||
io.Float.Input(
|
||||
"learning_rate",
|
||||
default=0.0005,
|
||||
min=0.0000001,
|
||||
max=1.0,
|
||||
step=0.0000001,
|
||||
tooltip="The learning rate to use for training.",
|
||||
),
|
||||
"rank": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 8,
|
||||
"min": 1,
|
||||
"max": 128,
|
||||
"tooltip": "The rank of the LoRA layers.",
|
||||
},
|
||||
io.Int.Input(
|
||||
"rank",
|
||||
default=8,
|
||||
min=1,
|
||||
max=128,
|
||||
tooltip="The rank of the LoRA layers.",
|
||||
),
|
||||
"optimizer": (
|
||||
["AdamW", "Adam", "SGD", "RMSprop"],
|
||||
{
|
||||
"default": "AdamW",
|
||||
"tooltip": "The optimizer to use for training.",
|
||||
},
|
||||
io.Combo.Input(
|
||||
"optimizer",
|
||||
options=["AdamW", "Adam", "SGD", "RMSprop"],
|
||||
default="AdamW",
|
||||
tooltip="The optimizer to use for training.",
|
||||
),
|
||||
"loss_function": (
|
||||
["MSE", "L1", "Huber", "SmoothL1"],
|
||||
{
|
||||
"default": "MSE",
|
||||
"tooltip": "The loss function to use for training.",
|
||||
},
|
||||
io.Combo.Input(
|
||||
"loss_function",
|
||||
options=["MSE", "L1", "Huber", "SmoothL1"],
|
||||
default="MSE",
|
||||
tooltip="The loss function to use for training.",
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
|
||||
},
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)",
|
||||
),
|
||||
"training_dtype": (
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for training."},
|
||||
io.Combo.Input(
|
||||
"training_dtype",
|
||||
options=["bf16", "fp32"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for training.",
|
||||
),
|
||||
"lora_dtype": (
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||
io.Combo.Input(
|
||||
"lora_dtype",
|
||||
options=["bf16", "fp32"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for lora.",
|
||||
),
|
||||
"algorithm": (
|
||||
list(adapter_maps.keys()),
|
||||
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
|
||||
io.Combo.Input(
|
||||
"algorithm",
|
||||
options=list(adapter_maps.keys()),
|
||||
default=list(adapter_maps.keys())[0],
|
||||
tooltip="The algorithm to use for training.",
|
||||
),
|
||||
"gradient_checkpointing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Use gradient checkpointing for training.",
|
||||
}
|
||||
io.Boolean.Input(
|
||||
"gradient_checkpointing",
|
||||
default=True,
|
||||
tooltip="Use gradient checkpointing for training.",
|
||||
),
|
||||
"existing_lora": (
|
||||
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
{
|
||||
"default": "[None]",
|
||||
"tooltip": "The existing LoRA to append to. Set to None for new LoRA.",
|
||||
},
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
default="[None]",
|
||||
tooltip="The existing LoRA to append to. Set to None for new LoRA.",
|
||||
),
|
||||
},
|
||||
}
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="Model with LoRA applied"),
|
||||
io.Custom("LORA_MODEL").Output(tooltip="LoRA weights"),
|
||||
io.Custom("LOSS_MAP").Output(tooltip="Loss history"),
|
||||
io.Int.Output(tooltip="Total training steps"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT)
|
||||
RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps")
|
||||
FUNCTION = "train"
|
||||
CATEGORY = "training"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def train(
|
||||
self,
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
model,
|
||||
latents,
|
||||
positive,
|
||||
@@ -347,13 +411,56 @@ class TrainLoraNode:
        gradient_checkpointing,
        existing_lora,
    ):
        # Extract scalars from lists (due to is_input_list=True)
        model = model[0]
        batch_size = batch_size[0]
        steps = steps[0]
        grad_accumulation_steps = grad_accumulation_steps[0]
        learning_rate = learning_rate[0]
        rank = rank[0]
        optimizer = optimizer[0]
        loss_function = loss_function[0]
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
        existing_lora = existing_lora[0]

        # Handle latents - either single dict or list of dicts
        if len(latents) == 1:
            latents = latents[0]["samples"]  # Single latent dict
        else:
            latent_list = []
            for latent in latents:
                latent = latent["samples"]
                bs = latent.shape[0]
                if bs != 1:
                    for sub_latent in latent:
                        latent_list.append(sub_latent[None])
                else:
                    latent_list.append(latent)
            latents = latent_list

        # Handle conditioning - either single list or list of lists
        if len(positive) == 1:
            positive = positive[0]  # Single conditioning list
        else:
            # Multiple conditioning lists - flatten
            flat_positive = []
            for cond in positive:
                if isinstance(cond, list):
                    flat_positive.extend(cond)
                else:
                    flat_positive.append(cond)
            positive = flat_positive

        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

        # latents here can be list of different size latent or one large batch
        latents = latents["samples"]
        if isinstance(latents, list):
            all_shapes = set()
            latents = [t.to(dtype) for t in latents]
@@ -366,8 +473,8 @@ class TrainLoraNode:
                multi_res = False
                latents = torch.cat(latents, dim=0)
                num_images = len(latents)
        elif isinstance(latents, list):
            latents = latents["samples"].to(dtype)
        elif isinstance(latents, torch.Tensor):
            latents = latents.to(dtype)
            num_images = latents.shape[0]
        else:
            logging.error(f"Invalid latents type: {type(latents)}")
@@ -403,9 +510,7 @@ class TrainLoraNode:
                shape = m.weight.shape
                if len(shape) >= 2:
                    alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
                    dora_scale = existing_weights.get(
                        f"{key}.dora_scale", None
                    )
                    dora_scale = existing_weights.get(f"{key}.dora_scale", None)
                    for adapter_cls in adapters:
                        existing_adapter = adapter_cls.load(
                            n, existing_weights, alpha, dora_scale
@@ -417,7 +522,9 @@ class TrainLoraNode:
                adapter_cls = adapter_maps[algorithm]

                if existing_adapter is not None:
                    train_adapter = existing_adapter.to_train().to(lora_dtype)
                    train_adapter = existing_adapter.to_train().to(
                        lora_dtype
                    )
                else:
                    # Use LoRA with alpha=1.0 by default
                    train_adapter = adapter_cls.create_train(
@@ -441,7 +548,9 @@ class TrainLoraNode:
                if hasattr(m, "bias") and m.bias is not None:
                    key = "{}.bias".format(n)
                    bias = torch.nn.Parameter(
                        torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
                        torch.zeros(
                            m.bias.shape, dtype=lora_dtype, requires_grad=True
                        )
                    )
                    bias_module = BiasDiff(bias)
                    lora_sd["{}.diff_b".format(n)] = bias
@@ -469,25 +578,31 @@ class TrainLoraNode:

        # setup models
        if gradient_checkpointing:
            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
            for m in find_all_highest_child_module_with_forward(
                mp.model.diffusion_model
            ):
                patch(m)
        mp.model.requires_grad_(False)
        comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
        comfy.model_management.load_models_gpu(
            [mp], memory_required=1e20, force_full_load=True
        )

        # Setup sampler and guider like in test script
        loss_map = {"loss": []}

        def loss_callback(loss):
            loss_map["loss"].append(loss)

        train_sampler = TrainSampler(
            criterion,
            optimizer,
            loss_callback=loss_callback,
            batch_size=batch_size,
            grad_acc=grad_accumulation_steps,
            total_steps=steps*grad_accumulation_steps,
            total_steps=steps * grad_accumulation_steps,
            seed=seed,
            training_dtype=dtype,
            real_dataset=latents if multi_res else None
            real_dataset=latents if multi_res else None,
        )
        guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
        guider.set_conds(positive)  # Set conditioning from input
@@ -505,7 +620,7 @@ class TrainLoraNode:
                latents,
                train_sampler,
                sigmas,
                seed=noise.seed
                seed=noise.seed,
            )
        finally:
            for m in mp.model.modules():
@@ -518,111 +633,116 @@ class TrainLoraNode:
        for param in lora_sd:
            lora_sd[param] = lora_sd[param].to(lora_dtype)

        return (mp, lora_sd, loss_map, steps + existing_steps)
        return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)


class LoraModelLoader:
    def __init__(self):
        self.loaded_lora = None
class LoraModelLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoraModelLoader",
            display_name="Load LoRA Model",
            category="loaders",
            is_experimental=True,
            inputs=[
                io.Model.Input(
                    "model", tooltip="The diffusion model the LoRA will be applied to."
                ),
                io.Custom("LORA_MODEL").Input(
                    "lora", tooltip="The LoRA model to apply to the diffusion model."
                ),
                io.Float.Input(
                    "strength_model",
                    default=1.0,
                    min=-100.0,
                    max=100.0,
                    tooltip="How strongly to modify the diffusion model. This value can be negative.",
                ),
            ],
            outputs=[
                io.Model.Output(tooltip="The modified diffusion model."),
            ],
        )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
                "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}),
                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
    FUNCTION = "load_lora_model"

    CATEGORY = "loaders"
    DESCRIPTION = "Load Trained LoRA weights from Train LoRA node."
    EXPERIMENTAL = True

    def load_lora_model(self, model, lora, strength_model):
    def execute(cls, model, lora, strength_model):
        if strength_model == 0:
            return (model, )
            return io.NodeOutput(model)

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        return (model_lora, )
        model_lora, _ = comfy.sd.load_lora_for_models(
            model, None, lora, strength_model, 0
        )
        return io.NodeOutput(model_lora)


class SaveLoRA:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
class SaveLoRA(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveLoRA",
            display_name="Save LoRA Weights",
            category="loaders",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LORA_MODEL").Input(
                    "lora",
                    tooltip="The LoRA model to save. Do not use the model with LoRA layers.",
                ),
                io.String.Input(
                    "prefix",
                    default="loras/ComfyUI_trained_lora",
                    tooltip="The prefix to use for the saved LoRA file.",
                ),
                io.Int.Input(
                    "steps",
                    optional=True,
                    tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
                ),
            ],
            outputs=[],
        )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora": (
                    IO.LORA_MODEL,
                    {
                        "tooltip": "The LoRA model to save. Do not use the model with LoRA layers."
                    },
                ),
                "prefix": (
                    "STRING",
                    {
                        "default": "loras/ComfyUI_trained_lora",
                        "tooltip": "The prefix to use for the saved LoRA file.",
                    },
                ),
            },
            "optional": {
                "steps": (
                    IO.INT,
                    {
                        "forceInput": True,
                        "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
                    },
                ),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    OUTPUT_NODE = True

    def save(self, lora, prefix, steps=None):
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir)
    def execute(cls, lora, prefix, steps=None):
        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(prefix, output_dir)
        )
        if steps is None:
            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        else:
            output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        safetensors.torch.save_file(lora, output_checkpoint)
        return {}
        return io.NodeOutput()


class LossGraphNode:
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
class LossGraphNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LossGraphNode",
            display_name="Plot Loss Graph",
            category="training",
            is_experimental=True,
            is_output_node=True,
            inputs=[
                io.Custom("LOSS_MAP").Input(
                    "loss", tooltip="Loss map from training node."
                ),
                io.String.Input(
                    "filename_prefix",
                    default="loss_graph",
                    tooltip="Prefix for the saved loss graph image.",
                ),
            ],
            outputs=[],
            hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
        )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "loss": (IO.LOSS_MAP, {"default": {}}),
                "filename_prefix": (IO.STRING, {"default": "loss_graph"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "plot_loss"
    OUTPUT_NODE = True
    CATEGORY = "training"
    EXPERIMENTAL = True
    DESCRIPTION = "Plots the loss graph and saves it to the output directory."

    def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
    def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None):
        loss_values = loss["loss"]
        width, height = 800, 480
        margin = 40
@@ -665,43 +785,27 @@ class LossGraphNode:
            (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
        )

        metadata = None
        if not args.disable_metadata:
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
        # Convert PIL image to tensor for PreviewImage
        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]  # [1, H, W, 3]

        date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        img.save(
            os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"),
            pnginfo=metadata,
        )
        return {
            "ui": {
                "images": [
                    {
                        "filename": f"{filename_prefix}_{date}.png",
                        "subfolder": "",
                        "type": "temp",
                    }
                ]
            }
        }
        # Return preview UI
        return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls))


NODE_CLASS_MAPPINGS = {
    "TrainLoraNode": TrainLoraNode,
    "SaveLoRANode": SaveLoRA,
    "LoraModelLoader": LoraModelLoader,
    "LossGraphNode": LossGraphNode,
}
# ========== Extension Setup ==========

NODE_DISPLAY_NAME_MAPPINGS = {
    "TrainLoraNode": "Train LoRA",
    "SaveLoRANode": "Save LoRA Weights",
    "LoraModelLoader": "Load LoRA Model",
    "LossGraphNode": "Plot Loss Graph",
}

class TrainingExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TrainLoraNode,
            LoraModelLoader,
            SaveLoRA,
            LossGraphNode,
        ]


async def comfy_entrypoint() -> TrainingExtension:
    return TrainingExtension()