diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 4789d7d53..2f1f304f8 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1,5 +1,6 @@ import logging import os +import math import json import numpy as np @@ -623,6 +624,79 @@ class TextProcessingNode(io.ComfyNode): # ========== Image Transform Nodes ========== +class ResizeImagesToSameSizeNode(ImageProcessingNode): + node_id = "ResizeImagesToSameSize" + display_name = "Resize Images to Same Size" + description = "Resize all images to the same width and height." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), + io.Combo.Input( + "mode", + options=["stretch", "crop_center", "pad"], + default="stretch", + tooltip="Resize mode.", + ), + ] + + @classmethod + def _process(cls, image, width, height, mode): + img = tensor_to_pil(image) + + if mode == "stretch": + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "crop_center": + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + if img.width != width or img.height != height: + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "pad": + img.thumbnail((width, height), Image.Resampling.LANCZOS) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + paste_x = (width - img.width) // 2 + paste_y = (height - img.height) // 2 + new_img.paste(img, (paste_x, paste_y)) + img = new_img + + return pil_to_tensor(img) + + +class ResizeImagesToPixelCountNode(ImageProcessingNode): + node_id = "ResizeImagesToPixelCount" + display_name = "Resize Images to Pixel Count" + description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "pixel_count", + default=512 * 512, + min=1, + max=8192 * 8192, + tooltip="Target pixel count.", + ), + io.Int.Input( + "steps", + default=64, + min=1, + max=128, + tooltip="The stepping for resize width/height.", + ), + ] + + @classmethod + def _process(cls, image, pixel_count, steps): + img = tensor_to_pil(image) + w, h = img.size + pixel_count_ratio = math.sqrt(pixel_count / (w * h)) + new_w = int(w * pixel_count_ratio / steps) * steps + new_h = int(h * pixel_count_ratio / steps) * steps + logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + class ResizeImagesByShorterEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByShorterEdge" display_name = "Resize Images by Shorter Edge" @@ -727,6 +801,29 @@ class RandomCropImagesNode(ImageProcessingNode): return pil_to_tensor(img) +class FlipImagesNode(ImageProcessingNode): + node_id = "FlipImages" + display_name = "Flip Images" + description = "Flip all images horizontally or vertically." + extra_inputs = [ + io.Combo.Input( + "direction", + options=["horizontal", "vertical"], + default="horizontal", + tooltip="Flip direction.", + ), + ] + + @classmethod + def _process(cls, image, direction): + img = tensor_to_pil(image) + if direction == "horizontal": + img = img.transpose(Image.FLIP_LEFT_RIGHT) + else: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + return pil_to_tensor(img) + + class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" display_name = "Normalize Images" @@ -1125,6 +1222,99 @@ class MergeTextListsNode(TextProcessingNode): # ========== Training Dataset Nodes ========== +class ResolutionBucket(io.ComfyNode): + """Bucket latents and conditions by resolution for efficient batch training.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionBucket", + display_name="Resolution Bucket", + category="dataset", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts to bucket by resolution.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists (must match latents length).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of batched latent dicts, one per resolution bucket.", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of condition lists, one per resolution bucket.", + ), + ], + ) + + @classmethod + def execute(cls, latents, conditioning): + # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1 + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})." + ) + + # Flatten latents and conditions to individual samples + flat_latents = [] # list of (C, H, W) tensors + flat_conditions = [] # list of condition lists + + for latent_dict, cond in zip(latents, conditioning): + samples = latent_dict["samples"] # (B, C, H, W) + batch_size = samples.shape[0] + + # cond is a list of conditions with length == batch_size + for i in range(batch_size): + flat_latents.append(samples[i]) # (C, H, W) + flat_conditions.append(cond[i]) # single condition + + # Group by resolution (H, W) + buckets = {} # (H, W) -> {"latents": list, "conditions": list} + + for latent, cond in zip(flat_latents, flat_conditions): + # latent shape is (C, H, W) + h, w = latent.shape[1], latent.shape[2] + key = (h, w) + + if key not in buckets: + buckets[key] = {"latents": [], "conditions": []} + + buckets[key]["latents"].append(latent) + buckets[key]["conditions"].append(cond) + + # Convert buckets to output format + output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, C, H, W) + output_conditions = [] # list[list[cond]] where each inner list has Bi conditions + + for (h, w), bucket_data in buckets.items(): + # Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W) + stacked_latents = torch.stack(bucket_data["latents"], dim=0) + output_latents.append({"samples": stacked_latents}) + + # Conditions stay as list of condition lists + output_conditions.append(bucket_data["conditions"]) + + logging.info( + f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples" + ) + + logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples") + return io.NodeOutput(output_latents, output_conditions) + + class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" @@ -1373,7 +1563,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f, weights_only=True) + shard_data = torch.load(f) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1399,10 +1589,13 @@ class DatasetExtension(ComfyExtension): SaveImageDataSetToFolderNode, SaveImageTextDataSetToFolderNode, # Image transform nodes + ResizeImagesToSameSizeNode, + ResizeImagesToPixelCountNode, ResizeImagesByShorterEdgeNode, ResizeImagesByLongerEdgeNode, CenterCropImagesNode, RandomCropImagesNode, + FlipImagesNode, NormalizeImagesNode, AdjustBrightnessNode, AdjustContrastNode, @@ -1425,6 +1618,7 @@ class DatasetExtension(ComfyExtension): MakeTrainingDataset, SaveTrainingDataset, LoadTrainingDataset, + ResolutionBucket, ] diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index cb24ab709..67cd0a200 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -65,6 +65,7 @@ class TrainSampler(comfy.samplers.Sampler): seed=0, training_dtype=torch.bfloat16, real_dataset=None, + bucket_latents=None, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -75,6 +76,22 @@ class TrainSampler(comfy.samplers.Sampler): self.seed = seed self.training_dtype = training_dtype self.real_dataset: list[torch.Tensor] | None = real_dataset + # Bucket mode data + self.bucket_latents: list[torch.Tensor] | None = bucket_latents # list of (Bi, C, Hi, Wi) + # Precompute bucket offsets and weights for sampling + if bucket_latents is not None: + self.bucket_offsets = [0] + bucket_sizes = [] + for lat in bucket_latents: + bucket_sizes.append(lat.shape[0]) + self.bucket_offsets.append(self.bucket_offsets[-1] + lat.shape[0]) + self.num_images = self.bucket_offsets[-1] + # Weights for sampling buckets proportional to their size + self.bucket_weights = torch.tensor(bucket_sizes, dtype=torch.float32) + else: + self.bucket_offsets = None + self.bucket_weights = None + self.num_images = None def fwd_bwd( self, @@ -142,9 +159,49 @@ class TrainSampler(comfy.samplers.Sampler): noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( self.seed + i * 1000 ) - indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - if self.real_dataset is None: + if self.bucket_latents is not None: + # Bucket mode: sample bucket (weighted by size), then sample batch from bucket + bucket_idx = torch.multinomial(self.bucket_weights, 1).item() + bucket_latent = self.bucket_latents[bucket_idx] # (Bi, C, Hi, Wi) + bucket_size = bucket_latent.shape[0] + bucket_offset = self.bucket_offsets[bucket_idx] + + # Sample indices from this bucket (use all if bucket_size < batch_size) + actual_batch_size = min(self.batch_size, bucket_size) + relative_indices = torch.randperm(bucket_size)[:actual_batch_size].tolist() + # Convert to absolute indices for fwd_bwd (cond is flattened, use absolute index) + absolute_indices = [bucket_offset + idx for idx in relative_indices] + + batch_latent = bucket_latent[relative_indices].to(latent_image) # (actual_batch_size, C, H, W) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ).to(batch_latent.device) + for _ in range(actual_batch_size) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) + + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, # Use flattened cond with absolute indices + absolute_indices, + extra_args, + self.num_images, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}", "bucket": bucket_idx}) + + elif self.real_dataset is None: + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() batch_latent = torch.stack([latent_image[i] for i in indicies]) batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( batch_latent.device @@ -172,6 +229,7 @@ class TrainSampler(comfy.samplers.Sampler): self.loss_callback(loss.item()) pbar.set_postfix({"loss": f"{loss.item():.4f}"}) else: + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() total_loss = 0 for index in indicies: single_latent = self.real_dataset[index].to(latent_image) @@ -385,6 +443,16 @@ class TrainLoraNode(io.ComfyNode): default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), + io.Boolean.Input( + "bucket_mode", + default=False, + tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", + ), + io.Boolean.Input( + "offloading", + default=False, + tooltip="", + ), ], outputs=[ io.Model.Output( @@ -419,6 +487,8 @@ class TrainLoraNode(io.ComfyNode): algorithm, gradient_checkpointing, existing_lora, + bucket_mode, + offloading, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -435,21 +505,31 @@ class TrainLoraNode(io.ComfyNode): algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] + bucket_mode = bucket_mode[0] - # Handle latents - either single dict or list of dicts - if len(latents) == 1: - latents = latents[0]["samples"] # Single latent dict + if bucket_mode: + # Bucket mode: latents and conditions are already bucketed + # latents: list[{"samples": tensor}] where each tensor is (Bi, C, Hi, Wi) + # positive: list[list[cond]] where each inner list has Bi conditions + bucket_latents = [] + for latent_dict in latents: + bucket_latents.append(latent_dict["samples"]) # (Bi, C, Hi, Wi) + latents = bucket_latents else: - latent_list = [] - for latent in latents: - latent = latent["samples"] - bs = latent.shape[0] - if bs != 1: - for sub_latent in latent: - latent_list.append(sub_latent[None]) - else: - latent_list.append(latent) - latents = latent_list + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list # Handle conditioning - either single list or list of lists if len(positive) == 1: @@ -469,32 +549,44 @@ class TrainLoraNode(io.ComfyNode): lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - # latents here can be list of different size latent or one large batch - if isinstance(latents, list): - all_shapes = set() + if bucket_mode: + # In bucket mode, latents is list of tensors (Bi, C, Hi, Wi) + # positive is list of condition lists latents = [t.to(dtype) for t in latents] - for latent in latents: - all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") - if len(all_shapes) > 1: - multi_res = True - else: - multi_res = False - latents = torch.cat(latents, dim=0) - num_images = len(latents) - elif isinstance(latents, torch.Tensor): - latents = latents.to(dtype) - num_images = latents.shape[0] - else: - logging.error(f"Invalid latents type: {type(latents)}") + num_buckets = len(latents) + num_images = sum(t.shape[0] for t in latents) + multi_res = False # Not using multi_res path in bucket mode - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") - if len(positive) == 1 and num_images > 1: - positive = positive * num_images - elif len(positive) != num_images: - raise ValueError( - f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." - ) + logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + for i, lat in enumerate(latents): + logging.info(f" Bucket {i}: shape {lat.shape}") + else: + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + positive = positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) with torch.inference_mode(False): lora_sd = {} @@ -592,9 +684,11 @@ class TrainLoraNode(io.ComfyNode): ): patch(m) mp.model.requires_grad_(False) + torch.cuda.empty_cache() comfy.model_management.load_models_gpu( - [mp], memory_required=1e20, force_full_load=True + [mp], memory_required=1e20, force_full_load=not offloading ) + torch.cuda.empty_cache() # Setup sampler and guider like in test script loss_map = {"loss": []} @@ -602,35 +696,68 @@ class TrainLoraNode(io.ComfyNode): def loss_callback(loss): loss_map["loss"].append(loss) - train_sampler = TrainSampler( - criterion, - optimizer, - loss_callback=loss_callback, - batch_size=batch_size, - grad_acc=grad_accumulation_steps, - total_steps=steps * grad_accumulation_steps, - seed=seed, - training_dtype=dtype, - real_dataset=latents if multi_res else None, - ) + if bucket_mode: + # Bucket mode: pass bucket data to sampler + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + bucket_latents=latents, + ) + else: + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + grad_acc=grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, + seed=seed, + training_dtype=dtype, + real_dataset=latents if multi_res else None, + ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) - guider.set_conds(positive) # Set conditioning from input + # In bucket mode we still send flatten positive to set_conds + guider.set_conds(positive) # Training loop try: # Generate dummy sigmas and noise sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) - if multi_res: + if bucket_mode: + # Use first bucket's first latent as dummy for guider + dummy_latent = latents[0][:1].repeat(num_images, 1, 1, 1) + guider.sample( + noise.generate_noise({"samples": dummy_latent}), + dummy_latent, + train_sampler, + sigmas, + seed=noise.seed, + ) + elif multi_res: # use first latent as dummy latent if multi_res latents = latents[0].repeat(num_images, 1, 1, 1) - guider.sample( - noise.generate_noise({"samples": latents}), - latents, - train_sampler, - sigmas, - seed=noise.seed, - ) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) + else: + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed, + ) finally: for m in mp.model.modules(): unpatch(m)