From 32b44f5d1c0b2ac9de81a172909d8bb23ff4eb41 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Mon, 3 Nov 2025 00:33:54 +0800
Subject: [PATCH] make training node work with our dataset system

---
 comfy_extras/nodes_dataset.py | 52 +++++++++++++++++++++++++++--------
 comfy_extras/nodes_train.py   | 23 ++++++++++++++--
 2 files changed, 61 insertions(+), 14 deletions(-)

diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
index 1aa49a502..941dc02c9 100644
--- a/comfy_extras/nodes_dataset.py
+++ b/comfy_extras/nodes_dataset.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import pickle
+import math
 
 import numpy as np
 import torch
@@ -332,6 +333,30 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
         return output_images
 
 
+class ResizeImagesToPixelCountNode(ImageProcessingNode):
+    DESCRIPTION = "Resize images so that the total pixel count approximately matches the specified number while preserving aspect ratio."
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        base_inputs = super().INPUT_TYPES()
+        base_inputs["required"]["pixel_count"] = ("INT", {"default": 512 * 512, "min": 1, "max": 8192 * 8192, "step": 1, "tooltip": "Target total pixel count."})
+        base_inputs["required"]["steps"] = ("INT", {"default": 64, "min": 1, "max": 128, "step": 1, "tooltip": "Resized width/height are rounded down to a multiple of this value."})
+        return base_inputs
+
+    def _process(self, images, pixel_count, steps):
+        output_images = []
+        for img_tensor in images:
+            img = self._tensor_to_pil(img_tensor)
+            w, h = img.size
+            pixel_count_ratio = math.sqrt(pixel_count / (w * h))
+            new_w = max(steps, int(w * pixel_count_ratio / steps) * steps)
+            new_h = max(steps, int(h * pixel_count_ratio / steps) * steps)
+            logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
+            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+            output_images.append(self._pil_to_tensor(img))
+        return output_images
+
+
 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     DESCRIPTION = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
 
@@ -421,7 +446,7 @@ class RandomCropImagesNode(ImageProcessingNode):
         return base_inputs
 
     def _process(self, images, width, height, seed):
-        np.random.seed(seed)
+        np.random.seed(seed % (2**32 - 1))
         output_images = []
         for img_tensor in images:
             img = self._tensor_to_pil(img_tensor)
@@ -509,7 +534,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
         return base_inputs
 
     def _process(self, images, seed):
-        np.random.seed(seed)
+        np.random.seed(seed % (2**32 - 1))
         indices = np.random.permutation(len(images))
         return [images[i] for i in indices]
 
@@ -534,7 +559,7 @@ class ShuffleImageTextDatasetNode:
         }
 
     def process(self, images, texts, seed):
-        np.random.seed(seed)
+        np.random.seed(seed % (2**32 - 1))
         indices = np.random.permutation(len(images))
         shuffled_images = [images[i] for i in indices]
         shuffled_texts = [texts[i] for i in indices]
@@ -637,7 +662,7 @@ class MakeTrainingDataset:
             },
         }
 
-    RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
+    RETURN_TYPES = ("LATENT", "CONDITIONING")
     RETURN_NAMES = ("latents", "conditioning")
     FUNCTION = "make_dataset"
     CATEGORY = "dataset"
@@ -667,7 +692,8 @@ class MakeTrainingDataset:
         for img_tensor in images:
             # img_tensor is [1, H, W, 3]
             t = vae.encode(img_tensor[:,:,:,:3])
-            latents.append({"samples": t})
+            latents.append(t)
+        latents = {"samples": latents}
 
         # Encode texts with CLIP
         logging.info(f"Encoding {len(texts)} texts with CLIP...")
@@ -681,7 +707,7 @@ class MakeTrainingDataset:
             cond = clip.encode_from_tokens_scheduled(tokens)
             conditions.extend(cond)
 
-        logging.info(f"Created dataset with {len(latents)} latents and {len(conditions)} conditions.")
+        logging.info(f"Created dataset with {len(latents['samples'])} latents and {len(conditions)} conditions.")
 
         return (latents, conditions)
 
@@ -692,7 +718,7 @@ class SaveTrainingDataset:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "latents": ("LATENT_LIST", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
+                "latents": ("LATENT", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
                 "conditioning": ("CONDITIONING", {"tooltip": "Conditioning list from MakeTrainingDataset."}),
                 "folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder to save dataset (inside output directory)."}),
                 "shard_size": ("INT", {"default": 1000, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of samples per shard file."}),
@@ -708,7 +734,7 @@ class SaveTrainingDataset:
 
     def save_dataset(self, latents, conditioning, folder_name, shard_size):
         # Validate lengths match
-        if len(latents) != len(conditioning):
+        if len(latents["samples"]) != len(conditioning):
             raise ValueError(
                 f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
                 f"Something went wrong in dataset preparation."
@@ -719,7 +745,7 @@ class SaveTrainingDataset:
         os.makedirs(output_dir, exist_ok=True)
 
         # Prepare data pairs
-        num_samples = len(latents)
+        num_samples = len(latents["samples"])
        num_shards = (num_samples + shard_size - 1) // shard_size  # Ceiling division
 
         logging.info(f"Saving {num_samples} samples to {num_shards} shards in {output_dir}...")
@@ -731,7 +757,7 @@ class SaveTrainingDataset:
 
             # Get shard data
             shard_data = {
-                "latents": latents[start_idx:end_idx],
+                "latents": latents["samples"][start_idx:end_idx],
                 "conditioning": conditioning[start_idx:end_idx],
             }
 
@@ -770,7 +796,7 @@ class LoadTrainingDataset:
             },
         }
 
-    RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
+    RETURN_TYPES = ("LATENT", "CONDITIONING")
     RETURN_NAMES = ("latents", "conditioning")
     FUNCTION = "load_dataset"
     CATEGORY = "dataset"
@@ -811,7 +837,7 @@ class LoadTrainingDataset:
             logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
 
         logging.info(f"Successfully loaded {len(all_latents)} samples from {dataset_dir}.")
-        return (all_latents, all_conditioning)
+        return ({"samples": all_latents}, all_conditioning)
 
 
 NODE_CLASS_MAPPINGS = {
@@ -821,6 +847,7 @@ NODE_CLASS_MAPPINGS = {
     "SaveImageTextDataSetToFolderNode": SaveImageTextDataSetToFolderNode,
     # Image transforms
     "ResizeImagesToSameSizeNode": ResizeImagesToSameSizeNode,
+    "ResizeImagesToPixelCountNode": ResizeImagesToPixelCountNode,
     "ResizeImagesByShorterEdgeNode": ResizeImagesByShorterEdgeNode,
     "ResizeImagesByLongerEdgeNode": ResizeImagesByLongerEdgeNode,
     "CenterCropImagesNode": CenterCropImagesNode,
@@ -852,6 +879,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SaveImageTextDataSetToFolderNode": "Save Simple Image and Text Dataset to Folder",
     # Image transforms
     "ResizeImagesToSameSizeNode": "Resize Images to Same Size",
+    "ResizeImagesToPixelCountNode": "Resize Images to Pixel Count",
     "ResizeImagesByShorterEdgeNode": "Resize Images by Shorter Edge",
     "ResizeImagesByLongerEdgeNode": "Resize Images by Longer Edge",
     "CenterCropImagesNode": "Center Crop Images",
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 9e6ec6780..b092dd0d7 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -558,8 +558,27 @@ class TrainLoraNode:
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)
 
-        latents = latents["samples"].to(dtype)
-        num_images = latents.shape[0]
+        # latents here can be a list of per-image latents or one pre-batched tensor
+        latents = latents["samples"]
+        if isinstance(latents, list):
+            all_shapes = set()
+            latents = [t.to(dtype) for t in latents]
+            for latent in latents:
+                all_shapes.add(latent.shape)
+            logging.info(f"Latent shapes: {all_shapes}")
+            if len(all_shapes) > 1:
+                raise ValueError(
+                    "Latents with different shapes are not currently supported"
+                )
+            else:
+                latents = torch.cat(latents, dim=0)
+            num_images = len(latents)
+        elif isinstance(latents, torch.Tensor):
+            latents = latents.to(dtype)
+            num_images = latents.shape[0]
+        else:
+            raise ValueError(f"Invalid latents type: {type(latents)}")
+
         logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
         if len(positive) == 1 and num_images > 1:
             positive = positive * num_images
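
Note (reviewer sketch, not part of the patch): the resize math in ResizeImagesToPixelCountNode scales both edges by sqrt(pixel_count / (w * h)) and rounds each edge down to a multiple of `steps`. A minimal standalone sketch of that computation; `target_size` is a hypothetical helper name, not something defined in the patch.

    import math

    def target_size(w: int, h: int, pixel_count: int = 512 * 512, steps: int = 64) -> tuple[int, int]:
        # Scale factor that maps w*h to roughly pixel_count while keeping the aspect ratio.
        ratio = math.sqrt(pixel_count / (w * h))
        # Round each edge down to a multiple of `steps`, but never below one step.
        new_w = max(steps, int(w * ratio / steps) * steps)
        new_h = max(steps, int(h * ratio / steps) * steps)
        return new_w, new_h

    print(target_size(1000, 750))  # (576, 384): 221184 pixels, a bit under the 262144 target because edges round down

Because the rounding is always downward, the resulting pixel count lands at or below the requested target.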
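Note (reviewer sketch, not part of the patch): with this change, MakeTrainingDataset and LoadTrainingDataset return a standard LATENT dict whose "samples" entry is a Python list of per-image latents, while plain LATENT inputs keep a single batched tensor; TrainLoraNode now accepts both. A minimal sketch of that consumer logic, with purely illustrative tensor shapes:

    import torch

    # Illustrative stand-in for the dataset output: "samples" holds a list of
    # per-image latents rather than one pre-batched tensor.
    latents = {"samples": [torch.randn(1, 4, 64, 64) for _ in range(3)]}

    samples = latents["samples"]
    if isinstance(samples, list):
        shapes = {t.shape for t in samples}
        if len(shapes) > 1:
            # Mixed resolutions cannot be stacked into a single batch.
            raise ValueError(f"Latents with different shapes are not supported: {shapes}")
        batch = torch.cat(samples, dim=0)    # [N, C, H, W]
    elif isinstance(samples, torch.Tensor):
        batch = samples                      # already one batched tensor
    else:
        raise ValueError(f"Invalid latents type: {type(samples)}")

    num_images = batch.shape[0]
    print(num_images)  # 3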