Refactoring with better layout for maintainability
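
Split TrainSampler.sample() into per-mode step helpers (_train_step_bucket_mode, _train_step_standard_mode, _train_step_multires_mode) plus a shared _generate_batch_sigmas, and move bucket offset/weight setup into _init_bucket_data. TrainLoraNode.execute() is likewise broken up into module-level helpers for latent and conditioning processing, LoRA adapter setup, optimizer/loss-function creation, and the training loop. Intended to be behavior-preserving; only the code layout changes.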

Kohaku-Blueleaf 2025-12-01 23:53:20 +08:00
parent 7a93c55a9f
commit bf573e94a2


@@ -80,6 +80,14 @@ class TrainSampler(comfy.samplers.Sampler):
        self.bucket_latents: list[torch.Tensor] | None = bucket_latents  # list of (Bi, C, Hi, Wi)
        # Precompute bucket offsets and weights for sampling
        if bucket_latents is not None:
+            self._init_bucket_data(bucket_latents)
+        else:
+            self.bucket_offsets = None
+            self.bucket_weights = None
+            self.num_images = None

+    def _init_bucket_data(self, bucket_latents):
+        """Initialize bucket offsets and weights for sampling."""
        self.bucket_offsets = [0]
        bucket_sizes = []
        for lat in bucket_latents:
@@ -88,10 +96,6 @@ class TrainSampler(comfy.samplers.Sampler):
        self.num_images = self.bucket_offsets[-1]
        # Weights for sampling buckets proportional to their size
        self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
-        else:
-            self.bucket_offsets = None
-            self.bucket_weights = None
-            self.num_images = None

    def fwd_bwd(
        self,
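For reference, a minimal standalone sketch of what _init_bucket_data ends up computing; the loop body falls between these two hunks, so it is inferred from the surrounding lines and should be read as an assumption, not the committed code:

import torch

def init_bucket_data_sketch(bucket_latents):
    # bucket_latents: list of (Bi, C, Hi, Wi) tensors, one per resolution bucket
    bucket_offsets = [0]
    bucket_sizes = []
    for lat in bucket_latents:
        bucket_sizes.append(lat.shape[0])                        # inferred loop body
        bucket_offsets.append(bucket_offsets[-1] + lat.shape[0])
    num_images = bucket_offsets[-1]
    # Buckets are later drawn with torch.multinomial, proportional to their size
    bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32)
    return bucket_offsets, bucket_weights, num_images

# Two buckets holding 3 and 5 latents -> offsets [0, 3, 8], weights [3., 5.], 8 images
offsets, weights, num_images = init_bucket_data_sketch(
    [torch.zeros(3, 4, 32, 32), torch.zeros(5, 4, 48, 32)]
)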
@@ -132,36 +136,19 @@ class TrainSampler(comfy.samplers.Sampler):
        bwd_loss.backward()
        return loss

-    def sample(
-        self,
-        model_wrap,
-        sigmas,
-        extra_args,
-        callback,
-        noise,
-        latent_image=None,
-        denoise_mask=None,
-        disable_pbar=False,
-    ):
-        model_wrap.conds = process_cond_list(model_wrap.conds)
-        cond = model_wrap.conds["positive"]
-        dataset_size = sigmas.size(0)
-        torch.cuda.empty_cache()
-        ui_pbar = ProgressBar(self.total_steps)
-        for i in (
-            pbar := trange(
-                self.total_steps,
-                desc="Training LoRA",
-                smoothing=0.01,
-                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
-            )
-        ):
-            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
-                self.seed + i * 1000
-            )
-            if self.bucket_latents is not None:
-                # Bucket mode: sample bucket (weighted by size), then sample batch from bucket
+    def _generate_batch_sigmas(self, model_wrap, batch_size, device):
+        """Generate random sigma values for a batch."""
+        batch_sigmas = [
+            model_wrap.inner_model.model_sampling.percent_to_sigma(
+                torch.rand((1,)).item()
+            )
+            for _ in range(batch_size)
+        ]
+        return torch.tensor(batch_sigmas).to(device)

+    def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar):
+        """Execute one training step in bucket mode."""
+        # Sample bucket (weighted by size), then sample batch from bucket
        bucket_idx = torch.multinomial(self.bucket_weights, 1).item()
        bucket_latent = self.bucket_latents[bucket_idx]  # (Bi, C, Hi, Wi)
        bucket_size = bucket_latent.shape[0]
@@ -177,13 +164,7 @@ class TrainSampler(comfy.samplers.Sampler):
        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
            batch_latent.device
        )
-        batch_sigmas = [
-            model_wrap.inner_model.model_sampling.percent_to_sigma(
-                torch.rand((1,)).item()
-            ).to(batch_latent.device)
-            for _ in range(actual_batch_size)
-        ]
-        batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+        batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device)

        loss = self.fwd_bwd(
            model_wrap,
@@ -200,19 +181,14 @@ class TrainSampler(comfy.samplers.Sampler):
            self.loss_callback(loss.item())
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx})

-            elif self.real_dataset is None:
+    def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
+        """Execute one training step in standard (non-bucket, non-multi-res) mode."""
        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
        batch_latent = torch.stack([latent_image[i] for i in indicies])
        batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(
            batch_latent.device
        )
-        batch_sigmas = [
-            model_wrap.inner_model.model_sampling.percent_to_sigma(
-                torch.rand((1,)).item()
-            )
-            for _ in range(min(self.batch_size, dataset_size))
-        ]
-        batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+        batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device)

        loss = self.fwd_bwd(
            model_wrap,
@@ -228,7 +204,9 @@ class TrainSampler(comfy.samplers.Sampler):
        if self.loss_callback:
            self.loss_callback(loss.item())
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

-            else:
+    def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar):
+        """Execute one training step in multi-resolution mode (real_dataset is set)."""
        indicies = torch.randperm(dataset_size)[: self.batch_size].tolist()
        total_loss = 0
        for index in indicies:
@@ -260,6 +238,41 @@ class TrainSampler(comfy.samplers.Sampler):
            self.loss_callback(total_loss.item())
        pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})

+    def sample(
+        self,
+        model_wrap,
+        sigmas,
+        extra_args,
+        callback,
+        noise,
+        latent_image=None,
+        denoise_mask=None,
+        disable_pbar=False,
+    ):
+        model_wrap.conds = process_cond_list(model_wrap.conds)
+        cond = model_wrap.conds["positive"]
+        dataset_size = sigmas.size(0)
+        torch.cuda.empty_cache()
+        ui_pbar = ProgressBar(self.total_steps)
+        for i in (
+            pbar := trange(
+                self.total_steps,
+                desc="Training LoRA",
+                smoothing=0.01,
+                disable=not comfy.utils.PROGRESS_BAR_ENABLED,
+            )
+        ):
+            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(
+                self.seed + i * 1000
+            )
+            if self.bucket_latents is not None:
+                self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar)
+            elif self.real_dataset is None:
+                self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
+            else:
+                self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)

            if (i + 1) % self.grad_acc == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
@@ -341,6 +354,360 @@ def unpatch(m):
        del m.org_forward

def _process_latents_bucket_mode(latents):
"""Process latents for bucket mode training.
Args:
latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
Returns:
list of latent tensors
"""
bucket_latents = []
for latent_dict in latents:
bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi)
return bucket_latents
def _process_latents_standard_mode(latents):
"""Process latents for standard (non-bucket) mode training.
Args:
latents: list of latent dicts or single latent dict
Returns:
Processed latents (tensor or list of tensors)
"""
if len(latents) == 1:
return latents[0]["samples"] # Single latent dict
latent_list = []
for latent in latents:
latent = latent["samples"]
bs = latent.shape[0]
if bs != 1:
for sub_latent in latent:
latent_list.append(sub_latent[None])
else:
latent_list.append(latent)
return latent_list
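# Example: [{"samples": t1}, {"samples": t2}] with t1 shaped (2, C, H, W) and t2 shaped (1, C, H, W)
# yields a list of three (1, C, H, W) tensors; a single dict returns its "samples" tensor directly.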
def _process_conditioning(positive):
"""Process conditioning - either single list or list of lists.
Args:
positive: list of conditioning
Returns:
Flattened conditioning list
"""
if len(positive) == 1:
return positive[0] # Single conditioning list
# Multiple conditioning lists - flatten
flat_positive = []
for cond in positive:
if isinstance(cond, list):
flat_positive.extend(cond)
else:
flat_positive.append(cond)
return flat_positive
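# Example: _process_conditioning([[c1], [c2, c3]]) -> [c1, c2, c3]; _process_conditioning([conds]) -> conds.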
def _prepare_latents_and_count(latents, dtype, bucket_mode):
"""Convert latents to dtype and compute image counts.
Args:
latents: Latents (tensor, list of tensors, or bucket list)
dtype: Target dtype
bucket_mode: Whether bucket mode is enabled
Returns:
tuple: (processed_latents, num_images, multi_res)
"""
if bucket_mode:
# In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
latents = [t.to(dtype) for t in latents]
num_buckets = len(latents)
num_images = sum(t.shape[0] for t in latents)
multi_res = False # Not using multi_res path in bucket mode
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
for i, lat in enumerate(latents):
logging.info(f" Bucket {i}: shape {lat.shape}")
return latents, num_images, multi_res
# Non-bucket mode
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
multi_res = True
else:
multi_res = False
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, torch.Tensor):
latents = latents.to(dtype)
num_images = latents.shape[0]
multi_res = False
else:
logging.error(f"Invalid latents type: {type(latents)}")
num_images = 0
multi_res = False
return latents, num_images, multi_res
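# Example: bucket_mode=True with tensors shaped (3, C, H, W) and (5, C, H, W) returns
# (cast latents, num_images=8, multi_res=False); a non-bucket list with mixed shapes
# keeps the list and returns multi_res=True instead of concatenating.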
def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
"""Validate conditioning count matches image count, expand if needed.
Args:
positive: Conditioning list
num_images: Number of images
bucket_mode: Whether bucket mode is enabled
Returns:
Validated/expanded conditioning list
Raises:
ValueError: If conditioning count doesn't match image count
"""
if bucket_mode:
return positive # Skip validation in bucket mode
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
return positive * num_images
elif len(positive) != num_images:
raise ValueError(
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
)
return positive
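# Example: a single condition with num_images=4 (non-bucket) is repeated to 4 entries;
# a mismatched count raises ValueError; bucket_mode returns positive unchanged.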
def _load_existing_lora(existing_lora):
"""Load existing LoRA weights if provided.
Args:
existing_lora: LoRA filename or "[None]"
Returns:
tuple: (existing_weights dict, existing_steps int)
"""
if existing_lora == "[None]":
return {}, 0
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
existing_weights = {}
if lora_path:
existing_weights = comfy.utils.load_torch_file(lora_path)
return existing_weights, existing_steps
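# Example: "trained_lora_10_steps_20250225_203716" -> existing_steps=10 plus the loaded weights;
# "[None]" -> ({}, 0).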
def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank):
"""Create a weight adapter for a module with weight.
Args:
module: The module to create adapter for
module_name: Name of the module
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (train_adapter, lora_params dict)
"""
key = f"{module_name}.weight"
shape = module.weight.shape
lora_params = {}
if len(shape) >= 2:
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
# Try to load existing adapter
existing_adapter = None
for adapter_cls in adapters:
existing_adapter = adapter_cls.load(
module_name, existing_weights, alpha, dora_scale
)
if existing_adapter is not None:
break
if existing_adapter is None:
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
else:
# Use LoRA with alpha=1.0 by default
train_adapter = adapter_cls.create_train(
module.weight, rank=rank, alpha=1.0
).to(lora_dtype)
for name, parameter in train_adapter.named_parameters():
lora_params[f"{module_name}.{name}"] = parameter
return train_adapter, lora_params
else:
# 1D weight - use BiasDiff
diff = torch.nn.Parameter(
torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True)
)
diff_module = BiasDiff(diff)
lora_params[f"{module_name}.diff"] = diff
return diff_module, lora_params
def _create_bias_adapter(module, module_name, lora_dtype):
"""Create a bias adapter for a module with bias.
Args:
module: The module with bias
module_name: Name of the module
lora_dtype: dtype for LoRA weights
Returns:
tuple: (bias_module, lora_params dict)
"""
bias = torch.nn.Parameter(
torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True)
)
bias_module = BiasDiff(bias)
lora_params = {f"{module_name}.diff_b": bias}
return bias_module, lora_params
def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
"""Setup all LoRA adapters on the model.
Args:
mp: Model patcher
existing_weights: Dict of existing LoRA weights
algorithm: Algorithm name for new adapters
lora_dtype: dtype for LoRA weights
rank: Rank for new LoRA adapters
Returns:
tuple: (lora_sd dict, all_weight_adapters list)
"""
lora_sd = {}
all_weight_adapters = []
for n, m in mp.model.named_modules():
if hasattr(m, "weight_function"):
if m.weight is not None:
adapter, params = _create_weight_adapter(
m, n, existing_weights, algorithm, lora_dtype, rank
)
lora_sd.update(params)
key = f"{n}.weight"
mp.add_weight_wrapper(key, adapter)
all_weight_adapters.append(adapter)
if hasattr(m, "bias") and m.bias is not None:
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
lora_sd.update(bias_params)
key = f"{n}.bias"
mp.add_weight_wrapper(key, bias_adapter)
all_weight_adapters.append(bias_adapter)
return lora_sd, all_weight_adapters
def _create_optimizer(optimizer_name, parameters, learning_rate):
"""Create optimizer based on name.
Args:
optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop")
parameters: Parameters to optimize
learning_rate: Learning rate
Returns:
Optimizer instance
"""
if optimizer_name == "Adam":
return torch.optim.Adam(parameters, lr=learning_rate)
elif optimizer_name == "AdamW":
return torch.optim.AdamW(parameters, lr=learning_rate)
elif optimizer_name == "SGD":
return torch.optim.SGD(parameters, lr=learning_rate)
elif optimizer_name == "RMSprop":
return torch.optim.RMSprop(parameters, lr=learning_rate)
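# Example: _create_optimizer("AdamW", lora_sd.values(), learning_rate) -> torch.optim.AdamW instance;
# an unrecognized name falls through and returns None.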
def _create_loss_function(loss_function_name):
"""Create loss function based on name.
Args:
loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1")
Returns:
Loss function instance
"""
if loss_function_name == "MSE":
return torch.nn.MSELoss()
elif loss_function_name == "L1":
return torch.nn.L1Loss()
elif loss_function_name == "Huber":
return torch.nn.HuberLoss()
elif loss_function_name == "SmoothL1":
return torch.nn.SmoothL1Loss()
def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res):
"""Execute the training loop.
Args:
guider: The guider object
train_sampler: The training sampler
latents: Latent tensors
num_images: Number of images
seed: Random seed
bucket_mode: Whether bucket mode is enabled
multi_res: Whether multi-resolution mode is enabled
"""
sigmas = torch.tensor(range(num_images))
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if bucket_mode:
# Use first bucket's first latent as dummy for guider
dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": dummy_latent}),
dummy_latent,
train_sampler,
sigmas,
seed=noise.seed,
)
elif multi_res:
# use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1)
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
else:
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
train_sampler,
sigmas,
seed=noise.seed,
)
class TrainLoraNode(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -497,8 +864,8 @@ class TrainLoraNode(io.ComfyNode):
        grad_accumulation_steps = grad_accumulation_steps[0]
        learning_rate = learning_rate[0]
        rank = rank[0]
-        optimizer = optimizer[0]
-        loss_function = loss_function[0]
+        optimizer_name = optimizer[0]
+        loss_function_name = loss_function[0]
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
@@ -507,182 +874,48 @@ class TrainLoraNode(io.ComfyNode):
        existing_lora = existing_lora[0]
        bucket_mode = bucket_mode[0]

+        # Process latents based on mode
        if bucket_mode:
-            # Bucket mode: latents and conditions are already bucketed
-            # latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi)
-            # positive: list[list[cond]] where each inner list has Bi conditions
-            bucket_latents = []
-            for latent_dict in latents:
-                bucket_latents.append(latent_dict["samples"])  # (Bi, C, Hi, Wi)
-            latents = bucket_latents
+            latents = _process_latents_bucket_mode(latents)
        else:
-            # Handle latents - either single dict or list of dicts
-            if len(latents) == 1:
-                latents = latents[0]["samples"]  # Single latent dict
-            else:
-                latent_list = []
-                for latent in latents:
-                    latent = latent["samples"]
-                    bs = latent.shape[0]
-                    if bs != 1:
-                        for sub_latent in latent:
-                            latent_list.append(sub_latent[None])
-                    else:
-                        latent_list.append(latent)
-                latents = latent_list
-
-        # Handle conditioning - either single list or list of lists
-        if len(positive) == 1:
-            positive = positive[0]  # Single conditioning list
-        else:
-            # Multiple conditioning lists - flatten
-            flat_positive = []
-            for cond in positive:
-                if isinstance(cond, list):
-                    flat_positive.extend(cond)
-                else:
-                    flat_positive.append(cond)
-            positive = flat_positive
+            latents = _process_latents_standard_mode(latents)

+        # Process conditioning
+        positive = _process_conditioning(positive)

+        # Setup model and dtype
        mp = model.clone()
        dtype = node_helpers.string_to_torch_dtype(training_dtype)
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        mp.set_model_compute_dtype(dtype)

-        if bucket_mode:
-            # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi)
-            # positive is list of condition lists
-            latents = [t.to(dtype) for t in latents]
-            num_buckets = len(latents)
-            num_images = sum(t.shape[0] for t in latents)
-            multi_res = False  # Not using multi_res path in bucket mode
-            logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
-            for i, lat in enumerate(latents):
-                logging.info(f"  Bucket {i}: shape {lat.shape}")
-        else:
-            # latents here can be list of different size latent or one large batch
-            if isinstance(latents, list):
-                all_shapes = set()
-                latents = [t.to(dtype) for t in latents]
-                for latent in latents:
-                    all_shapes.add(latent.shape)
-                logging.info(f"Latent shapes: {all_shapes}")
-                if len(all_shapes) > 1:
-                    multi_res = True
-                else:
-                    multi_res = False
-                    latents = torch.cat(latents, dim=0)
-                num_images = len(latents)
-            elif isinstance(latents, torch.Tensor):
-                latents = latents.to(dtype)
-                num_images = latents.shape[0]
-            else:
-                logging.error(f"Invalid latents type: {type(latents)}")
-
-        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
-        if len(positive) == 1 and num_images > 1:
-            positive = positive * num_images
-        elif len(positive) != num_images:
-            raise ValueError(
-                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
-            )
+        # Prepare latents and compute counts
+        latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode)

+        # Validate and expand conditioning
+        positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)

        with torch.inference_mode(False):
-            lora_sd = {}
-            generator = torch.Generator()
-            generator.manual_seed(seed)
-
            # Load existing LoRA weights if provided
-            existing_weights = {}
-            existing_steps = 0
-            if existing_lora != "[None]":
-                lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
-                # Extract steps from filename like "trained_lora_10_steps_20250225_203716"
-                existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
-                if lora_path:
-                    existing_weights = comfy.utils.load_torch_file(lora_path)
+            existing_weights, existing_steps = _load_existing_lora(existing_lora)

-            all_weight_adapters = []
-            for n, m in mp.model.named_modules():
-                if hasattr(m, "weight_function"):
-                    if m.weight is not None:
-                        key = "{}.weight".format(n)
-                        shape = m.weight.shape
-                        if len(shape) >= 2:
-                            alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
-                            dora_scale = existing_weights.get(f"{key}.dora_scale", None)
-                            for adapter_cls in adapters:
-                                existing_adapter = adapter_cls.load(
-                                    n, existing_weights, alpha, dora_scale
-                                )
-                                if existing_adapter is not None:
-                                    break
-                            else:
-                                existing_adapter = None
-                            adapter_cls = adapter_maps[algorithm]
-                            if existing_adapter is not None:
-                                train_adapter = existing_adapter.to_train().to(
-                                    lora_dtype
-                                )
-                            else:
-                                # Use LoRA with alpha=1.0 by default
-                                train_adapter = adapter_cls.create_train(
-                                    m.weight, rank=rank, alpha=1.0
-                                ).to(lora_dtype)
-                            for name, parameter in train_adapter.named_parameters():
-                                lora_sd[f"{n}.{name}"] = parameter
-                            mp.add_weight_wrapper(key, train_adapter)
-                            all_weight_adapters.append(train_adapter)
-                        else:
-                            diff = torch.nn.Parameter(
-                                torch.zeros(
-                                    m.weight.shape, dtype=lora_dtype, requires_grad=True
-                                )
-                            )
-                            diff_module = BiasDiff(diff)
-                            mp.add_weight_wrapper(key, BiasDiff(diff))
-                            all_weight_adapters.append(diff_module)
-                            lora_sd["{}.diff".format(n)] = diff
-                    if hasattr(m, "bias") and m.bias is not None:
-                        key = "{}.bias".format(n)
-                        bias = torch.nn.Parameter(
-                            torch.zeros(
-                                m.bias.shape, dtype=lora_dtype, requires_grad=True
-                            )
-                        )
-                        bias_module = BiasDiff(bias)
-                        lora_sd["{}.diff_b".format(n)] = bias
-                        mp.add_weight_wrapper(key, BiasDiff(bias))
-                        all_weight_adapters.append(bias_module)
+            # Setup LoRA adapters
+            lora_sd, all_weight_adapters = _setup_lora_adapters(
+                mp, existing_weights, algorithm, lora_dtype, rank
+            )

-            if optimizer == "Adam":
-                optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "AdamW":
-                optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "SGD":
-                optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
-            elif optimizer == "RMSprop":
-                optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
-
-            # Setup loss function based on selection
-            if loss_function == "MSE":
-                criterion = torch.nn.MSELoss()
-            elif loss_function == "L1":
-                criterion = torch.nn.L1Loss()
-            elif loss_function == "Huber":
-                criterion = torch.nn.HuberLoss()
-            elif loss_function == "SmoothL1":
-                criterion = torch.nn.SmoothL1Loss()
+            # Create optimizer and loss function
+            optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate)
+            criterion = _create_loss_function(loss_function_name)

-            # setup models
+            # Setup gradient checkpointing
            if gradient_checkpointing:
                for m in find_all_highest_child_module_with_forward(
                    mp.model.diffusion_model
                ):
                    patch(m)

+            # Setup models for training
            mp.model.requires_grad_(False)
            torch.cuda.empty_cache()
            comfy.model_management.load_models_gpu(
@@ -690,14 +923,14 @@ class TrainLoraNode(io.ComfyNode):
            )
            torch.cuda.empty_cache()

-            # Setup sampler and guider like in test script
+            # Setup loss tracking
            loss_map = {"loss": []}

            def loss_callback(loss):
                loss_map["loss"].append(loss)

+            # Create sampler
            if bucket_mode:
-                # Bucket mode: pass bucket data to sampler
                train_sampler = TrainSampler(
                    criterion,
                    optimizer,
@@ -721,48 +954,20 @@ class TrainLoraNode(io.ComfyNode):
                training_dtype=dtype,
                real_dataset=latents if multi_res else None,
            )

+            # Setup guider
            guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
-            # In bucket mode we still send flatten positive to set_conds
            guider.set_conds(positive)

-            # Training loop
+            # Run training loop
            try:
-                # Generate dummy sigmas and noise
-                sigmas = torch.tensor(range(num_images))
-                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
-                if bucket_mode:
-                    # Use first bucket's first latent as dummy for guider
-                    dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1)
-                    guider.sample(
-                        noise.generate_noise({"samples": dummy_latent}),
-                        dummy_latent,
-                        train_sampler,
-                        sigmas,
-                        seed=noise.seed,
-                    )
-                elif multi_res:
-                    # use first latent as dummy latent if multi_res
-                    latents = latents[0].repeat(num_images, 1, 1, 1)
-                    guider.sample(
-                        noise.generate_noise({"samples": latents}),
-                        latents,
-                        train_sampler,
-                        sigmas,
-                        seed=noise.seed,
-                    )
-                else:
-                    guider.sample(
-                        noise.generate_noise({"samples": latents}),
-                        latents,
-                        train_sampler,
-                        sigmas,
-                        seed=noise.seed,
-                    )
+                _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res)
            finally:
                for m in mp.model.modules():
                    unpatch(m)

            del train_sampler, optimizer

+            # Finalize adapters
            for adapter in all_weight_adapters:
                adapter.requires_grad_(False)