From d5aea27817f095b57f468119a423da3ded39f8f3 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 29 Oct 2025 12:02:52 +0800
Subject: [PATCH] Add encoded dataset caching mechanism

---
 comfy_extras/nodes_dataset.py | 220 ++++++++++++++++++++++++++++++++--
 nodes.py                      |   1 +
 2 files changed, 212 insertions(+), 9 deletions(-)

diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
index 1bf5a1320..1aa49a502 100644
--- a/comfy_extras/nodes_dataset.py
+++ b/comfy_extras/nodes_dataset.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import pickle
 
 import numpy as np
 import torch
@@ -50,7 +51,7 @@ class LoadImageDataSetFromFolderNode:
 
     RETURN_TYPES = ("IMAGE_LIST",)
     FUNCTION = "load_images"
-    CATEGORY = "loaders"
+    CATEGORY = "dataset"
     EXPERIMENTAL = True
     DESCRIPTION = "Loads a batch of images from a directory for training."
 
@@ -77,7 +78,7 @@ class LoadImageTextDataSetFromFolderNode:
 
     RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST",)
     FUNCTION = "load_images"
-    CATEGORY = "loaders"
+    CATEGORY = "dataset"
     EXPERIMENTAL = True
     DESCRIPTION = "Loads a batch of images and caption from a directory for training."
 
@@ -115,8 +116,6 @@ class LoadImageTextDataSetFromFolderNode:
             else:
                 captions.append("")
 
-        width = width if width != -1 else None
-        height = height if height != -1 else None
         output_tensor = load_and_process_images(image_files, sub_input_dir)
         logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
 
@@ -181,7 +180,7 @@ class SaveImageDataSetToFolderNode:
     RETURN_TYPES = ()
     OUTPUT_NODE = True
     FUNCTION = "save_images"
-    CATEGORY = "loaders"
+    CATEGORY = "dataset"
     EXPERIMENTAL = True
     DESCRIPTION = "Saves a batch of images to a directory."
 
@@ -208,7 +207,7 @@ class SaveImageTextDataSetToFolderNode:
     RETURN_TYPES = ()
     OUTPUT_NODE = True
     FUNCTION = "save_images"
-    CATEGORY = "loaders"
+    CATEGORY = "dataset"
     EXPERIMENTAL = True
     DESCRIPTION = "Saves a batch of images and captions to a directory."
 
@@ -232,7 +231,7 @@ class SaveImageTextDataSetToFolderNode:
 class ImageProcessingNode:
     """Base class for image processing nodes that operate on IMAGE_LIST."""
 
-    CATEGORY = "image/transforms"
+    CATEGORY = "dataset/image"
     EXPERIMENTAL = True
     RETURN_TYPES = ("IMAGE_LIST",)
     FUNCTION = "process"
@@ -269,7 +268,7 @@ class ImageProcessingNode:
 class TextProcessingNode:
     """Base class for text processing nodes that operate on TEXT_LIST."""
 
-    CATEGORY = "text/transforms"
+    CATEGORY = "dataset/text"
     EXPERIMENTAL = True
     RETURN_TYPES = ("TEXT_LIST",)
     FUNCTION = "process"
@@ -518,7 +517,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
 class ShuffleImageTextDatasetNode:
     """Special node that shuffles both images and texts together (doesn't inherit from base class)."""
 
-    CATEGORY = "image/transforms"
+    CATEGORY = "dataset/image"
     EXPERIMENTAL = True
     RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST")
     FUNCTION = "process"
@@ -620,6 +619,201 @@ class StripWhitespaceNode(TextProcessingNode):
         return [text.strip() for text in texts]
 
 
+# ========== Training Dataset Nodes ==========
+
+class MakeTrainingDataset:
+    """Encode images with VAE and texts with CLIP to create a training dataset."""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "images": ("IMAGE_LIST", {"tooltip": "List of images to encode."}),
+                "vae": ("VAE", {"tooltip": "VAE model for encoding images to latents."}),
+                "clip": ("CLIP", {"tooltip": "CLIP model for encoding text to conditioning."}),
+            },
+            "optional": {
+                "texts": ("TEXT_LIST", {"tooltip": "List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string)."}),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
+    RETURN_NAMES = ("latents", "conditioning")
+    FUNCTION = "make_dataset"
+    CATEGORY = "dataset"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Encodes images with VAE and texts with CLIP to create a training dataset. Returns a list of latents and a flat conditioning list."
+
+    def make_dataset(self, images, vae, clip, texts=None):
+        # Handle text list
+        num_images = len(images)
+
+        if texts is None or len(texts) == 0:
+            # Treat as [""] for unconditional training
+            texts = [""]
+
+        if len(texts) == 1 and num_images > 1:
+            # Repeat single text for all images
+            texts = texts * num_images
+        elif len(texts) != num_images:
+            raise ValueError(
+                f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
+                f"Text list should have length {num_images}, 1, or 0."
+            )
+
+        # Encode images with VAE
+        logging.info(f"Encoding {num_images} images with VAE...")
+        latents = []
+        for img_tensor in images:
+            # img_tensor is [1, H, W, 3]
+            t = vae.encode(img_tensor[:,:,:,:3])
+            latents.append({"samples": t})
+
+        # Encode texts with CLIP
+        logging.info(f"Encoding {len(texts)} texts with CLIP...")
+        conditions = []
+        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
+        for text in texts:
+            if text == "":
+                conditions.extend(empty_cond)
+            else:
+                tokens = clip.tokenize(text)
+                cond = clip.encode_from_tokens_scheduled(tokens)
+                conditions.extend(cond)
+
+        logging.info(f"Created dataset with {len(latents)} latents and {len(conditions)} conditions.")
+        return (latents, conditions)
+
+
+class SaveTrainingDataset:
+    """Save encoded training dataset (latents + conditioning) to disk."""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "latents": ("LATENT_LIST", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
+                "conditioning": ("CONDITIONING", {"tooltip": "Conditioning list from MakeTrainingDataset."}),
+                "folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder to save dataset (inside output directory)."}),
+                "shard_size": ("INT", {"default": 1000, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of samples per shard file."}),
+            },
+        }
+
+    RETURN_TYPES = ()
+    OUTPUT_NODE = True
+    FUNCTION = "save_dataset"
+    CATEGORY = "dataset"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Saves a training dataset to disk in sharded pickle files. Each shard contains (latent, conditioning) pairs."
+
+    def save_dataset(self, latents, conditioning, folder_name, shard_size):
+        # Validate lengths match
+        if len(latents) != len(conditioning):
+            raise ValueError(
+                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
+                f"Something went wrong in dataset preparation."
+            )
+
+        # Create output directory
+        output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Prepare data pairs
+        num_samples = len(latents)
+        num_shards = (num_samples + shard_size - 1) // shard_size  # Ceiling division
+
+        logging.info(f"Saving {num_samples} samples to {num_shards} shards in {output_dir}...")
+
+        # Save data in shards
+        for shard_idx in range(num_shards):
+            start_idx = shard_idx * shard_size
+            end_idx = min(start_idx + shard_size, num_samples)
+
+            # Get shard data
+            shard_data = {
+                "latents": latents[start_idx:end_idx],
+                "conditioning": conditioning[start_idx:end_idx],
+            }
+
+            # Save shard
+            shard_filename = f"shard_{shard_idx:04d}.pkl"
+            shard_path = os.path.join(output_dir, shard_filename)
+
+            with open(shard_path, "wb") as f:
+                pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+            logging.info(f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)")
+
+        # Save metadata
+        metadata = {
+            "num_samples": num_samples,
+            "num_shards": num_shards,
+            "shard_size": shard_size,
+        }
+        metadata_path = os.path.join(output_dir, "metadata.json")
+        import json
+        with open(metadata_path, "w") as f:
+            json.dump(metadata, f, indent=2)
+
+        logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
+        return {}
+
+
+class LoadTrainingDataset:
+    """Load encoded training dataset from disk."""
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder containing the saved dataset (inside output directory)."}),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
+    RETURN_NAMES = ("latents", "conditioning")
+    FUNCTION = "load_dataset"
+    CATEGORY = "dataset"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Loads a training dataset from disk. Returns a list of latents and a flat conditioning list."
+
+    def load_dataset(self, folder_name):
+        # Get dataset directory
+        dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
+
+        if not os.path.exists(dataset_dir):
+            raise ValueError(f"Dataset directory not found: {dataset_dir}")
+
+        # Find all shard files
+        shard_files = sorted([
+            f for f in os.listdir(dataset_dir)
+            if f.startswith("shard_") and f.endswith(".pkl")
+        ])
+
+        if not shard_files:
+            raise ValueError(f"No shard files found in {dataset_dir}")
+
+        logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
+
+        # Load all shards
+        all_latents = []
+        all_conditioning = []
+
+        for shard_file in shard_files:
+            shard_path = os.path.join(dataset_dir, shard_file)
+
+            with open(shard_path, "rb") as f:
+                shard_data = pickle.load(f)
+
+            all_latents.extend(shard_data["latents"])
+            all_conditioning.extend(shard_data["conditioning"])
+
+            logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
+
+        logging.info(f"Successfully loaded {len(all_latents)} samples from {dataset_dir}.")
+        return (all_latents, all_conditioning)
+
+
 NODE_CLASS_MAPPINGS = {
     "LoadImageDataSetFromFolderNode": LoadImageDataSetFromFolderNode,
     "LoadImageTextDataSetFromFolderNode": LoadImageTextDataSetFromFolderNode,
@@ -645,6 +839,10 @@ NODE_CLASS_MAPPINGS = {
     "AddTextSuffixNode": AddTextSuffixNode,
     "ReplaceTextNode": ReplaceTextNode,
     "StripWhitespaceNode": StripWhitespaceNode,
+    # Training dataset nodes
+    "MakeTrainingDataset": MakeTrainingDataset,
+    "SaveTrainingDataset": SaveTrainingDataset,
+    "LoadTrainingDataset": LoadTrainingDataset,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -672,4 +870,8 @@
     "AddTextSuffixNode": "Add Text Suffix",
     "ReplaceTextNode": "Replace Text",
     "StripWhitespaceNode": "Strip Whitespace",
+    # Training dataset nodes
+    "MakeTrainingDataset": "Make Training Dataset",
+    "SaveTrainingDataset": "Save Training Dataset",
+    "LoadTrainingDataset": "Load Training Dataset",
 }
diff --git a/nodes.py b/nodes.py
index 88d712993..89a96fee2 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2275,6 +2275,7 @@ async def init_builtin_extra_nodes():
         "nodes_images.py",
         "nodes_video_model.py",
         "nodes_train.py",
+        "nodes_dataset.py",
         "nodes_sag.py",
         "nodes_perpneg.py",
         "nodes_stable3d.py",
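
Notes (not part of the patch): the round trip these nodes implement is "encode once, train many times". MakeTrainingDataset runs the expensive VAE/CLIP passes, SaveTrainingDataset caches the results, and LoadTrainingDataset restores them without touching the models. A minimal sketch of driving the three nodes directly from Python, assuming `images` is an IMAGE_LIST (each entry a [1, H, W, 3] tensor) and `vae`/`clip` were loaded by an ordinary ComfyUI workflow:

    from comfy_extras.nodes_dataset import (
        MakeTrainingDataset,
        SaveTrainingDataset,
        LoadTrainingDataset,
    )

    # One caption is broadcast to every image, per make_dataset().
    latents, conditioning = MakeTrainingDataset().make_dataset(
        images, vae, clip, texts=["a photo of a cat"]
    )

    # Writes shard_0000.pkl, shard_0001.pkl, ... plus metadata.json
    # under <output>/training_dataset/.
    SaveTrainingDataset().save_dataset(
        latents, conditioning, folder_name="training_dataset", shard_size=1000
    )

    # In a later run, skip the VAE/CLIP encoding entirely.
    latents, conditioning = LoadTrainingDataset().load_dataset("training_dataset")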
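One design point: SaveTrainingDataset insists that len(latents) == len(conditioning). MakeTrainingDataset satisfies this as long as clip.encode_from_tokens_scheduled() returns exactly one conditioning entry per caption; because make_dataset() extends a flat list, a scheduling setup that yields several entries per caption would make the conditioning list longer than the latent list and trip the length check at save time.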
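The on-disk format is plain pickle shards plus a JSON manifest, so a cached dataset can be inspected without ComfyUI. A sketch, assuming the default folder name, an "output" working directory, and that torch is importable (the pickled latents are torch tensors):

    import json
    import os
    import pickle

    dataset_dir = os.path.join("output", "training_dataset")  # assumed location

    # metadata.json records num_samples, num_shards, and shard_size.
    with open(os.path.join(dataset_dir, "metadata.json")) as f:
        meta = json.load(f)

    # Each shard holds parallel lists: shard["latents"][i] is
    # {"samples": tensor} and shard["conditioning"][i] is the matching
    # caption's conditioning entry.
    with open(os.path.join(dataset_dir, "shard_0000.pkl"), "rb") as f:
        shard = pickle.load(f)

    print(meta["num_samples"], len(shard["latents"]))
    print(shard["latents"][0]["samples"].shape)

Note that LoadTrainingDataset reads every shard into memory at once; shard_size controls file granularity on disk, not peak memory at load time.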