diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 67cd0a200..a24d3b199 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -80,19 +80,23 @@ class TrainSampler(comfy.samplers.Sampler): self.bucket_latents: list[torch.Tensor] | None = bucket_latents # list of (Bi, C, Hi, Wi) # Precompute bucket offsets and weights for sampling if bucket_latents is not None: - self.bucket_offsets = [0] - bucket_sizes = [] - for lat in bucket_latents: - bucket_sizes.append(lat.shape[0]) - self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0]) - self.num_images = self.bucket_offsets[-1] - # Weights for sampling buckets proportional to their size - self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) + self._init_bucket_data(bucket_latents) else: self.bucket_offsets = None self.bucket_weights = None self.num_images = None + def _init_bucket_data(self, bucket_latents): + """Initialize bucket offsets and weights for sampling.""" + self.bucket_offsets = [0] + bucket_sizes = [] + for lat in bucket_latents: + bucket_sizes.append(lat.shape[0]) + self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0]) + self.num_images = self.bucket_offsets[-1] + # Weights for sampling buckets proportional to their size + self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) + def fwd_bwd( self, model_wrap, @@ -132,6 +136,108 @@ class TrainSampler(comfy.samplers.Sampler): bwd_loss.backward() return loss + def _generate_batch_sigmas(self, model_wrap, batch_size, device): + """Generate random sigma values for a batch.""" + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(batch_size) + ] + return torch.tensor(batch_sigmas).to(device) + + def _train_step_bucket_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, pbar): + """Execute one training step in bucket mode.""" + # Sample bucket (weighted by size), then sample batch from bucket + bucket_idx = torch.multinomial(self.bucket_weights, 1).item() + bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi) + bucket_size = bucket_latent.shape[0] + bucket_offset = self.bucket_offsets[bucket_idx] + + # Sample indices from this bucket (use all if bucket_size < batch_size) + actual_batch_size = min(self.batch_size, bucket_size) + relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist() + # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) + absolute_indices = [bucket_offset + idx for idx in relative_indices] + + batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, actual_batch_size, batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, # Use flattened cond with absolute indices + absolute_indices, + extra_args, + self.num_images, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx}) + + def _train_step_standard_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in standard (non-bucket, non-multi-res) mode.""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = self._generate_batch_sigmas(model_wrap, min(self.batch_size, dataset_size), batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + def _train_step_multires_mode(self, model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar): + """Execute one training step in multi-resolution mode (real_dataset is set).""" + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + def sample( self, model_wrap, @@ -161,104 +267,11 @@ class TrainSampler(comfy.samplers.Sampler): ) if self.bucket_latents is not None: - # Bucket mode: sample bucket (weighted by size), then sample batch from bucket - bucket_idx = torch.multinomial(self.bucket_weights, 1).item() - bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi) - bucket_size = bucket_latent.shape[0] - bucket_offset = self.bucket_offsets[bucket_idx] - - # Sample indices from this bucket (use all if bucket_size < batch_size) - actual_batch_size = min(self.batch_size, bucket_size) - relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist() - # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) - absolute_indices = [bucket_offset + idx for idx in relative_indices] - - batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( - batch_latent.device - ) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ).to(batch_latent.device) - for _ in range(actual_batch_size) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - batch_latent, - cond, # Use flattened cond with absolute indices - absolute_indices, - extra_args, - self.num_images, - bwd=True, - ) - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx}) - + self._train_step_bucket_mode(model_wrap, cond, extra_args, noisegen, latent_image, pbar) elif self.real_dataset is None: - indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( - batch_latent.device - ) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - batch_latent, - cond, - indicies, - extra_args, - dataset_size, - bwd=True, - ) - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + self._train_step_standard_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) else: - indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - total_loss = 0 - for index in indicies: - single_latent = self.real_dataset[index].to(latent_image) - batch_noise = noisegen.generate_noise( - {"samples": single_latent} - ).to(single_latent.device) - batch_sigmas = ( - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) - ) - batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) - loss = self.fwd_bwd( - model_wrap, - batch_sigmas, - batch_noise, - single_latent, - cond, - [index], - extra_args, - dataset_size, - bwd=False, - ) - total_loss += loss - total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() - if self.loss_callback: - self.loss_callback(total_loss.item()) - pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) + self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: self.optimizer.step() @@ -341,6 +354,360 @@ def unpatch(m): del m.org_forward +def _process_latents_bucket_mode(latents): + """Process latents for bucket mode training. + + Args: + latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) + + Returns: + list of latent tensors + """ + bucket_latents = [] + for latent_dict in latents: + bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) + return bucket_latents + + +def _process_latents_standard_mode(latents): + """Process latents for standard (non-bucket) mode training. + + Args: + latents: list of latent dicts or single latent dict + + Returns: + Processed latents (tensor or list of tensors) + """ + if len(latents) == 1: + return latents[0]["samples"] # Single latent dict + + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + return latent_list + + +def _process_conditioning(positive): + """Process conditioning - either single list or list of lists. + + Args: + positive: list of conditioning + + Returns: + Flattened conditioning list + """ + if len(positive) == 1: + return positive[0] # Single conditioning list + + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + return flat_positive + + +def _prepare_latents_and_count(latents, dtype, bucket_mode): + """Convert latents to dtype and compute image counts. + + Args: + latents: Latents (tensor, list of tensors, or bucket list) + dtype: Target dtype + bucket_mode: Whether bucket mode is enabled + + Returns: + tuple: (processed_latents, num_images, multi_res) + """ + if bucket_mode: + # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) + latents = [t.to(dtype) for t in latents] + num_buckets = len(latents) + num_images = sum(t.shape[0] for t in latents) + multi_res = False # Not using multi_res path in bucket mode + + logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + for i, lat in enumerate(latents): + logging.info(f" Bucket {i}: shape {lat.shape}") + return latents, num_images, multi_res + + # Non-bucket mode + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + multi_res = False + else: + logging.error(f"Invalid latents type: {type(latents)}") + num_images = 0 + multi_res = False + + return latents, num_images, multi_res + + +def _validate_and_expand_conditioning(positive, num_images, bucket_mode): + """Validate conditioning count matches image count, expand if needed. + + Args: + positive: Conditioning list + num_images: Number of images + bucket_mode: Whether bucket mode is enabled + + Returns: + Validated/expanded conditioning list + + Raises: + ValueError: If conditioning count doesn't match image count + """ + if bucket_mode: + return positive # Skip validation in bucket mode + + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + return positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) + return positive + + +def _load_existing_lora(existing_lora): + """Load existing LoRA weights if provided. + + Args: + existing_lora: LoRA filename or "[None]" + + Returns: + tuple: (existing_weights dict, existing_steps int) + """ + if existing_lora == "[None]": + return {}, 0 + + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + existing_weights = {} + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + return existing_weights, existing_steps + + +def _create_weight_adapter(module, module_name, existing_weights, algorithm, lora_dtype, rank): + """Create a weight adapter for a module with weight. + + Args: + module: The module to create adapter for + module_name: Name of the module + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (train_adapter, lora_params dict) + """ + key = f"{module_name}.weight" + shape = module.weight.shape + lora_params = {} + + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) + + # Try to load existing adapter + existing_adapter = None + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + module_name, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + + if existing_adapter is None: + adapter_cls = adapter_maps[algorithm] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + module.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + + for name, parameter in train_adapter.named_parameters(): + lora_params[f"{module_name}.{name}"] = parameter + + return train_adapter, lora_params + else: + # 1D weight - use BiasDiff + diff = torch.nn.Parameter( + torch.zeros(module.weight.shape, dtype=lora_dtype, requires_grad=True) + ) + diff_module = BiasDiff(diff) + lora_params[f"{module_name}.diff"] = diff + return diff_module, lora_params + + +def _create_bias_adapter(module, module_name, lora_dtype): + """Create a bias adapter for a module with bias. + + Args: + module: The module with bias + module_name: Name of the module + lora_dtype: dtype for LoRA weights + + Returns: + tuple: (bias_module, lora_params dict) + """ + bias = torch.nn.Parameter( + torch.zeros(module.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias) + lora_params = {f"{module_name}.diff_b": bias} + return bias_module, lora_params + + +def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup all LoRA adapters on the model. + + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list) + """ + lora_sd = {} + all_weight_adapters = [] + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + key = f"{n}.weight" + mp.add_weight_wrapper(key, adapter) + all_weight_adapters.append(adapter) + + if hasattr(m, "bias") and m.bias is not None: + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + + return lora_sd, all_weight_adapters + + +def _create_optimizer(optimizer_name, parameters, learning_rate): + """Create optimizer based on name. + + Args: + optimizer_name: Name of optimizer ("Adam", "AdamW", "SGD", "RMSprop") + parameters: Parameters to optimize + learning_rate: Learning rate + + Returns: + Optimizer instance + """ + if optimizer_name == "Adam": + return torch.optim.Adam(parameters, lr=learning_rate) + elif optimizer_name == "AdamW": + return torch.optim.AdamW(parameters, lr=learning_rate) + elif optimizer_name == "SGD": + return torch.optim.SGD(parameters, lr=learning_rate) + elif optimizer_name == "RMSprop": + return torch.optim.RMSprop(parameters, lr=learning_rate) + + +def _create_loss_function(loss_function_name): + """Create loss function based on name. + + Args: + loss_function_name: Name of loss function ("MSE", "L1", "Huber", "SmoothL1") + + Returns: + Loss function instance + """ + if loss_function_name == "MSE": + return torch.nn.MSELoss() + elif loss_function_name == "L1": + return torch.nn.L1Loss() + elif loss_function_name == "Huber": + return torch.nn.HuberLoss() + elif loss_function_name == "SmoothL1": + return torch.nn.SmoothL1Loss() + + +def _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res): + """Execute the training loop. + + Args: + guider: The guider object + train_sampler: The training sampler + latents: Latent tensors + num_images: Number of images + seed: Random seed + bucket_mode: Whether bucket mode is enabled + multi_res: Whether multi-resolution mode is enabled + """ + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + + if bucket_mode: + # Use first bucket's first latent as dummy for guider + dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": dummy_latent}), + dummy_latent, + train_sampler, + sigmas, + seed=noise.seed, + ) + elif multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + else: + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + + class TrainLoraNode(io.ComfyNode): @classmethod def define_schema(cls): @@ -497,8 +864,8 @@ class TrainLoraNode(io.ComfyNode): grad_accumulation_steps = grad_accumulation_steps[0] learning_rate = learning_rate[0] rank = rank[0] - optimizer = optimizer[0] - loss_function = loss_function[0] + optimizer_name = optimizer[0] + loss_function_name = loss_function[0] seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] @@ -507,182 +874,48 @@ class TrainLoraNode(io.ComfyNode): existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] + # Process latents based on mode if bucket_mode: - # Bucket mode: latents and conditions are already bucketed - # latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) - # positive: list[list[cond]] where each inner list has Bi conditions - bucket_latents = [] - for latent_dict in latents: - bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) - latents = bucket_latents + latents = _process_latents_bucket_mode(latents) else: - # Handle latents - either single dict or list of dicts - if len(latents) == 1: - latents = latents[0]["samples"] # Single latent dict - else: - latent_list = [] - for latent in latents: - latent = latent["samples"] - bs = latent.shape[0] - if bs != 1: - for sub_latent in latent: - latent_list.append(sub_latent[None]) - else: - latent_list.append(latent) - latents = latent_list + latents = _process_latents_standard_mode(latents) - # Handle conditioning - either single list or list of lists - if len(positive) == 1: - positive = positive[0] # Single conditioning list - else: - # Multiple conditioning lists - flatten - flat_positive = [] - for cond in positive: - if isinstance(cond, list): - flat_positive.extend(cond) - else: - flat_positive.append(cond) - positive = flat_positive + # Process conditioning + positive = _process_conditioning(positive) + # Setup model and dtype mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - if bucket_mode: - # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) - # positive is list of condition lists - latents = [t.to(dtype) for t in latents] - num_buckets = len(latents) - num_images = sum(t.shape[0] for t in latents) - multi_res = False # Not using multi_res path in bucket mode + # Prepare latents and compute counts + latents, num_images, multi_res = _prepare_latents_and_count(latents, dtype, bucket_mode) - logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") - for i, lat in enumerate(latents): - logging.info(f" Bucket {i}: shape {lat.shape}") - else: - # latents here can be list of different size latent or one large batch - if isinstance(latents, list): - all_shapes = set() - latents = [t.to(dtype) for t in latents] - for latent in latents: - all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") - if len(all_shapes) > 1: - multi_res = True - else: - multi_res = False - latents = torch.cat(latents, dim=0) - num_images = len(latents) - elif isinstance(latents, torch.Tensor): - latents = latents.to(dtype) - num_images = latents.shape[0] - else: - logging.error(f"Invalid latents type: {type(latents)}") - - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") - if len(positive) == 1 and num_images > 1: - positive = positive * num_images - elif len(positive) != num_images: - raise ValueError( - f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." - ) + # Validate and expand conditioning + positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) with torch.inference_mode(False): - lora_sd = {} - generator = torch.Generator() - generator.manual_seed(seed) - # Load existing LoRA weights if provided - existing_weights = {} - existing_steps = 0 - if existing_lora != "[None]": - lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) - # Extract steps from filename like "trained_lora_10_steps_20250225_203716" - existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) - if lora_path: - existing_weights = comfy.utils.load_torch_file(lora_path) + existing_weights, existing_steps = _load_existing_lora(existing_lora) - all_weight_adapters = [] - for n, m in mp.model.named_modules(): - if hasattr(m, "weight_function"): - if m.weight is not None: - key = "{}.weight".format(n) - shape = m.weight.shape - if len(shape) >= 2: - alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get(f"{key}.dora_scale", None) - for adapter_cls in adapters: - existing_adapter = adapter_cls.load( - n, existing_weights, alpha, dora_scale - ) - if existing_adapter is not None: - break - else: - existing_adapter = None - adapter_cls = adapter_maps[algorithm] + # Setup LoRA adapters + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) - if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to( - lora_dtype - ) - else: - # Use LoRA with alpha=1.0 by default - train_adapter = adapter_cls.create_train( - m.weight, rank=rank, alpha=1.0 - ).to(lora_dtype) - for name, parameter in train_adapter.named_parameters(): - lora_sd[f"{n}.{name}"] = parameter + # Create optimizer and loss function + optimizer = _create_optimizer(optimizer_name, lora_sd.values(), learning_rate) + criterion = _create_loss_function(loss_function_name) - mp.add_weight_wrapper(key, train_adapter) - all_weight_adapters.append(train_adapter) - else: - diff = torch.nn.Parameter( - torch.zeros( - m.weight.shape, dtype=lora_dtype, requires_grad=True - ) - ) - diff_module = BiasDiff(diff) - mp.add_weight_wrapper(key, BiasDiff(diff)) - all_weight_adapters.append(diff_module) - lora_sd["{}.diff".format(n)] = diff - if hasattr(m, "bias") and m.bias is not None: - key = "{}.bias".format(n) - bias = torch.nn.Parameter( - torch.zeros( - m.bias.shape, dtype=lora_dtype, requires_grad=True - ) - ) - bias_module = BiasDiff(bias) - lora_sd["{}.diff_b".format(n)] = bias - mp.add_weight_wrapper(key, BiasDiff(bias)) - all_weight_adapters.append(bias_module) - - if optimizer == "Adam": - optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) - elif optimizer == "AdamW": - optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) - elif optimizer == "SGD": - optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) - elif optimizer == "RMSprop": - optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) - - # Setup loss function based on selection - if loss_function == "MSE": - criterion = torch.nn.MSELoss() - elif loss_function == "L1": - criterion = torch.nn.L1Loss() - elif loss_function == "Huber": - criterion = torch.nn.HuberLoss() - elif loss_function == "SmoothL1": - criterion = torch.nn.SmoothL1Loss() - - # setup models + # Setup gradient checkpointing if gradient_checkpointing: for m in find_all_highest_child_module_with_forward( mp.model.diffusion_model ): patch(m) + + # Setup models for training mp.model.requires_grad_(False) torch.cuda.empty_cache() comfy.model_management.load_models_gpu( @@ -690,14 +923,14 @@ class TrainLoraNode(io.ComfyNode): ) torch.cuda.empty_cache() - # Setup sampler and guider like in test script + # Setup loss tracking loss_map = {"loss": []} def loss_callback(loss): loss_map["loss"].append(loss) + # Create sampler if bucket_mode: - # Bucket mode: pass bucket data to sampler train_sampler = TrainSampler( criterion, optimizer, @@ -721,48 +954,20 @@ class TrainLoraNode(io.ComfyNode): training_dtype=dtype, real_dataset=latents if multi_res else None, ) + + # Setup guider guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) - # In bucket mode we still send flatten positive to set_conds guider.set_conds(positive) - # Training loop + # Run training loop try: - # Generate dummy sigmas and noise - sigmas = torch.tensor(range(num_images)) - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) - if bucket_mode: - # Use first bucket's first latent as dummy for guider - dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) - guider.sample( - noise.generate_noise({"samples": dummy_latent}), - dummy_latent, - train_sampler, - sigmas, - seed=noise.seed, - ) - elif multi_res: - # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) - guider.sample( - noise.generate_noise({"samples": latents}), - latents, - train_sampler, - sigmas, - seed=noise.seed, - ) - else: - guider.sample( - noise.generate_noise({"samples": latents}), - latents, - train_sampler, - sigmas, - seed=noise.seed, - ) + _run_training_loop(guider, train_sampler, latents, num_images, seed, bucket_mode, multi_res) finally: for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer + # Finalize adapters for adapter in all_weight_adapters: adapter.requires_grad_(False)