diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
index 941dc02c9..820d8a4e9 100644
--- a/comfy_extras/nodes_dataset.py
+++ b/comfy_extras/nodes_dataset.py
@@ -6,9 +6,11 @@
 import math
 import numpy as np
 import torch
 from PIL import Image
+from typing_extensions import override

 import folder_paths
 import node_helpers
+from comfy_api.latest import ComfyExtension, io

 def load_and_process_images(image_files, input_dir):
@@ -41,22 +43,28 @@ def load_and_process_images(image_files, input_dir):
     return output_images


-class LoadImageDataSetFromFolderNode:
+class LoadImageDataSetFromFolderNode(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoadImageDataSetFromFolder",
+            display_name="Load Simple Image Dataset from Folder",
+            category="dataset",
+            is_experimental=True,
+            inputs=[
+                io.Combo.Input(
+                    "folder",
+                    options=folder_paths.get_input_subfolders(),
+                    tooltip="The folder to load images from.",
+                )
+            ],
+            outputs=[
+                io.Image.Output(is_output_list=True, tooltip="List of loaded images")
+            ],
+        )

-    RETURN_TYPES = ("IMAGE_LIST",)
-    FUNCTION = "load_images"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images from a directory for training."
-
-    def load_images(self, folder):
+    @classmethod
+    def execute(cls, folder):
         sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
         valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
         image_files = [
@@ -65,25 +73,32 @@ class LoadImageDataSetFromFolderNode:
             if any(f.lower().endswith(ext) for ext in valid_extensions)
         ]
         output_tensor = load_and_process_images(image_files, sub_input_dir)
-        return (output_tensor,)
+        return io.NodeOutput(output_tensor)
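
This first hunk is the template for every conversion in the file: the INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY / EXPERIMENTAL class attributes collapse into one io.Schema returned by a classmethod define_schema, and the entry point becomes a classmethod execute that returns io.NodeOutput. A minimal sketch of that shape, using only the API surface visible in this diff (the node itself is hypothetical):

```python
from comfy_api.latest import io

class PassThroughImagesNode(io.ComfyNode):  # hypothetical example, not in this PR
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="PassThroughImages",         # replaces the NODE_CLASS_MAPPINGS key
            display_name="Pass Through Images",  # replaces NODE_DISPLAY_NAME_MAPPINGS
            category="dataset",                  # was CATEGORY
            is_experimental=True,                # was EXPERIMENTAL
            inputs=[io.Image.Input("images", tooltip="Images to pass through.")],
            outputs=[io.Image.Output(is_output_list=True, tooltip="Same images")],
        )

    @classmethod
    def execute(cls, images):                    # was FUNCTION = "some_method"
        return io.NodeOutput(images)             # was `return (images,)`
```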


-class LoadImageTextDataSetFromFolderNode:
+class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoadImageTextDataSetFromFolder",
+            display_name="Load Simple Image and Text Dataset from Folder",
+            category="dataset",
+            is_experimental=True,
+            inputs=[
+                io.Combo.Input(
+                    "folder",
+                    options=folder_paths.get_input_subfolders(),
+                    tooltip="The folder to load images from.",
+                )
+            ],
+            outputs=[
+                io.Image.Output(is_output_list=True, tooltip="List of loaded images"),
+                io.String.Output(is_output_list=True, tooltip="List of text captions"),
+            ],
+        )

-    RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST",)
-    FUNCTION = "load_images"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images and caption from a directory for training."
-
-    def load_images(self, folder):
+    @classmethod
+    def execute(cls, folder):
         logging.info(f"Loading images from folder: {folder}")

         sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
@@ -99,13 +114,17 @@ class LoadImageTextDataSetFromFolderNode:
             repeat = 1
             if item.split("_")[0].isdigit():
                 repeat = int(item.split("_")[0])
-            image_files.extend([
-                os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
-            ] * repeat)
+            image_files.extend(
+                [
+                    os.path.join(path, f)
+                    for f in os.listdir(path)
+                    if any(f.lower().endswith(ext) for ext in valid_extensions)
+                ]
+                * repeat
+            )

         caption_file_path = [
-            f.replace(os.path.splitext(f)[1], ".txt")
-            for f in image_files
+            f.replace(os.path.splitext(f)[1], ".txt") for f in image_files
         ]
         captions = []
         for caption_file in caption_file_path:
@@ -120,7 +139,7 @@ class LoadImageTextDataSetFromFolderNode:

         output_tensor = load_and_process_images(image_files, sub_input_dir)
         logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
-        return (output_tensor, captions)
+        return io.NodeOutput(output_tensor, captions)


 def save_images_to_folder(image_list, output_dir, prefix="image"):
@@ -146,7 +165,11 @@ def save_images_to_folder(image_list, output_dir, prefix="image"):

         # If tensor is [C, H, W], permute to [H, W, C]
         if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]:
-            if img_tensor.shape[0] <= 4 and img_tensor.shape[1] > 4 and img_tensor.shape[2] > 4:
+            if (
+                img_tensor.shape[0] <= 4
+                and img_tensor.shape[1] > 4
+                and img_tensor.shape[2] > 4
+            ):
                 img_tensor = img_tensor.permute(1, 2, 0)

         # Convert to numpy and scale to 0-255
@@ -167,52 +190,78 @@ def save_images_to_folder(image_list, output_dir, prefix="image"):
     return saved_files
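
The `@@ -99` hunk keeps the digit-prefix repeat convention: a subfolder named N_something contributes each of its images N times. A small illustration of the rule (folder names are made up):

```python
def folder_repeat(name: str) -> int:
    # Mirrors the check above: a leading integer before "_" is a repeat count.
    head = name.split("_")[0]
    return int(head) if head.isdigit() else 1

assert folder_repeat("3_style") == 3     # listed three times
assert folder_repeat("portraits") == 1   # no prefix, listed once
assert folder_repeat("10_faces") == 10
```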


-class SaveImageDataSetToFolderNode:
+class SaveImageDataSetToFolderNode(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "images": ("IMAGE_LIST", {"tooltip": "List of images to save."}),
-                "folder_name": ("STRING", {"default": "dataset", "tooltip": "Name of the folder to save images to (inside output directory)."}),
-                "filename_prefix": ("STRING", {"default": "image", "tooltip": "Prefix for saved image filenames."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SaveImageDataSetToFolder",
+            display_name="Save Simple Image Dataset to Folder",
+            category="dataset",
+            is_experimental=True,
+            is_output_node=True,
+            is_input_list=True,  # Receive images as list
+            inputs=[
+                io.Image.Input("images", tooltip="List of images to save."),
+                io.String.Input(
+                    "folder_name",
+                    default="dataset",
+                    tooltip="Name of the folder to save images to (inside output directory).",
+                ),
+                io.String.Input(
+                    "filename_prefix",
+                    default="image",
+                    tooltip="Prefix for saved image filenames.",
+                ),
+            ],
+            outputs=[],
+        )

-    RETURN_TYPES = ()
-    OUTPUT_NODE = True
-    FUNCTION = "save_images"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Saves a batch of images to a directory."
+    @classmethod
+    def execute(cls, images, folder_name, filename_prefix):
+        # Extract scalar values
+        folder_name = folder_name[0]
+        filename_prefix = filename_prefix[0]

-    def save_images(self, images, folder_name, filename_prefix):
         output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
         saved_files = save_images_to_folder(images, output_dir, filename_prefix)

         logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
-        return {}
+        return io.NodeOutput()


-class SaveImageTextDataSetToFolderNode:
+class SaveImageTextDataSetToFolderNode(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "images": ("IMAGE_LIST", {"tooltip": "List of images to save."}),
-                "texts": ("TEXT_LIST", {"tooltip": "List of text captions to save."}),
-                "folder_name": ("STRING", {"default": "dataset", "tooltip": "Name of the folder to save images to (inside output directory)."}),
-                "filename_prefix": ("STRING", {"default": "image", "tooltip": "Prefix for saved image filenames."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SaveImageTextDataSetToFolder",
+            display_name="Save Simple Image and Text Dataset to Folder",
+            category="dataset",
+            is_experimental=True,
+            is_output_node=True,
+            is_input_list=True,  # Receive both images and texts as lists
+            inputs=[
+                io.Image.Input("images", tooltip="List of images to save."),
+                io.String.Input("texts", tooltip="List of text captions to save."),
+                io.String.Input(
+                    "folder_name",
+                    default="dataset",
+                    tooltip="Name of the folder to save images to (inside output directory).",
+                ),
+                io.String.Input(
+                    "filename_prefix",
+                    default="image",
+                    tooltip="Prefix for saved image filenames.",
+                ),
+            ],
+            outputs=[],
+        )

-    RETURN_TYPES = ()
-    OUTPUT_NODE = True
-    FUNCTION = "save_images"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Saves a batch of images and captions to a directory."
+    @classmethod
+    def execute(cls, images, texts, folder_name, filename_prefix):
+        # Extract scalar values
+        folder_name = folder_name[0]
+        filename_prefix = filename_prefix[0]

-    def save_images(self, images, texts, folder_name, filename_prefix):
         output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
         saved_files = save_images_to_folder(images, output_dir, filename_prefix)

@@ -224,92 +273,180 @@ class SaveImageTextDataSetToFolderNode:
             f.write(caption)

         logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
-        return {}
+        return io.NodeOutput()
+
+
+# ========== Helper Functions for Transform Nodes ==========
+
+
+def tensor_to_pil(img_tensor):
+    """Convert tensor to PIL Image."""
+    if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
+        img_tensor = img_tensor.squeeze(0)
+    img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
+    return Image.fromarray(img_array)
+
+
+def pil_to_tensor(img):
+    """Convert PIL Image to tensor."""
+    img_array = np.array(img).astype(np.float32) / 255.0
+    return torch.from_numpy(img_array)[None,]


 # ========== Base Classes for Transform Nodes ==========
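
tensor_to_pil and pil_to_tensor, promoted from instance methods to module level above, assume the ComfyUI image layout: a [1, H, W, C] float tensor in [0, 1]. A quick round-trip check under that assumption (exact up to uint8 quantization):

```python
import torch

t = torch.rand(1, 64, 64, 3)            # ComfyUI-style image tensor
img = tensor_to_pil(t)                  # -> 64x64 RGB PIL image
t2 = pil_to_tensor(img)                 # -> back to [1, 64, 64, 3] float32
assert tuple(t2.shape) == (1, 64, 64, 3)
assert (t - t2).abs().max() < 1 / 255 + 1e-6  # uint8 quantization error only
```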


-class ImageProcessingNode:
-    """Base class for image processing nodes that operate on IMAGE_LIST."""
-
-    CATEGORY = "dataset/image"
-    EXPERIMENTAL = True
-    RETURN_TYPES = ("IMAGE_LIST",)
-    FUNCTION = "process"
+class ImageProcessingNode(io.ComfyNode):
+    """Base class for image processing nodes that operate on lists of images.
+
+    Child classes should set:
+        node_id: Unique node identifier (required)
+        display_name: Display name (optional, defaults to node_id)
+        description: Node description (optional)
+        extra_inputs: List of additional io.Input objects beyond "images" (optional)
+
+    Child classes must implement:
+        _process(cls, images, **kwargs) -> list[tensor]
+    """
+
+    node_id = None
+    display_name = None
+    description = None
+    extra_inputs = []

     @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "images": ("IMAGE_LIST", {"tooltip": "List of images to process."}),
-            },
-        }
+    def define_schema(cls):
+        if cls.node_id is None:
+            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

-    def process(self, images, **kwargs):
-        """Default process function that calls _process for each image."""
-        return (self._process(images, **kwargs),)
+        inputs = [io.Image.Input("images", tooltip="List of images to process.")]
+        inputs.extend(cls.extra_inputs)

-    def _process(self, images, **kwargs):
-        """Override this method in subclasses to implement specific processing."""
-        raise NotImplementedError("Subclasses must implement _process method")
-
-    def _tensor_to_pil(self, img_tensor):
-        """Convert tensor to PIL Image."""
-        if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
-            img_tensor = img_tensor.squeeze(0)
-        img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
-        return Image.fromarray(img_array)
-
-    def _pil_to_tensor(self, img):
-        """Convert PIL Image to tensor."""
-        img_array = np.array(img).astype(np.float32) / 255.0
-        return torch.from_numpy(img_array)[None,]
-
-
-class TextProcessingNode:
-    """Base class for text processing nodes that operate on TEXT_LIST."""
-
-    CATEGORY = "dataset/text"
-    EXPERIMENTAL = True
-    RETURN_TYPES = ("TEXT_LIST",)
-    FUNCTION = "process"
+        return io.Schema(
+            node_id=cls.node_id,
+            display_name=cls.display_name or cls.node_id,
+            category="dataset/image",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=inputs,
+            outputs=[io.Image.Output(is_output_list=True, tooltip="Processed images")],
+        )

     @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "texts": ("TEXT_LIST", {"tooltip": "List of texts to process."}),
-            },
-        }
+    def execute(cls, images, **kwargs):
+        """Execute the node. Extracts scalar values and calls _process."""
+        # Extract scalar values from lists (due to is_input_list=True)
+        params = {}
+        for k, v in kwargs.items():
+            if isinstance(v, list) and len(v) == 1:
+                params[k] = v[0]
+            else:
+                params[k] = v

-    def process(self, texts, **kwargs):
-        """Default process function that calls _process."""
-        return (self._process(texts, **kwargs),)
+        result = cls._process(images, **params)
+        return io.NodeOutput(result)

-    def _process(self, texts, **kwargs):
-        """Override this method in subclasses to implement specific processing."""
-        raise NotImplementedError("Subclasses must implement _process method")
+    @classmethod
+    def _process(cls, images, **kwargs):
+        """Override this method in subclasses to implement specific processing.
+
+        Args:
+            images: list[tensor] - List of image tensors
+            **kwargs: Additional parameters (already extracted from lists)
+
+        Returns:
+            list[tensor] - Processed images
+        """
+        raise NotImplementedError(f"{cls.__name__} must implement _process method")
+
+
+class TextProcessingNode(io.ComfyNode):
+    """Base class for text processing nodes that operate on lists of texts.
+
+    Child classes should set:
+        node_id: Unique node identifier (required)
+        display_name: Display name (optional, defaults to node_id)
+        description: Node description (optional)
+        extra_inputs: List of additional io.Input objects beyond "texts" (optional)
+
+    Child classes must implement:
+        _process(cls, texts, **kwargs) -> list[str]
+    """
+
+    node_id = None
+    display_name = None
+    description = None
+    extra_inputs = []
+
+    @classmethod
+    def define_schema(cls):
+        if cls.node_id is None:
+            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
+
+        inputs = [io.String.Input("texts", tooltip="List of texts to process.")]
+        inputs.extend(cls.extra_inputs)
+
+        return io.Schema(
+            node_id=cls.node_id,
+            display_name=cls.display_name or cls.node_id,
+            category="dataset/text",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=inputs,
+            outputs=[io.String.Output(is_output_list=True, tooltip="Processed texts")],
+        )
+
+    @classmethod
+    def execute(cls, texts, **kwargs):
+        """Execute the node. Extracts scalar values and calls _process."""
+        # Extract scalar values from lists (due to is_input_list=True)
+        params = {}
+        for k, v in kwargs.items():
+            if isinstance(v, list) and len(v) == 1:
+                params[k] = v[0]
+            else:
+                params[k] = v
+
+        result = cls._process(texts, **params)
+        return io.NodeOutput(result)
+
+    @classmethod
+    def _process(cls, texts, **kwargs):
+        """Override this method in subclasses to implement specific processing.
+
+        Args:
+            texts: list[str] - List of text strings
+            **kwargs: Additional parameters (already extracted from lists)
+
+        Returns:
+            list[str] - Processed texts
+        """
+        raise NotImplementedError(f"{cls.__name__} must implement _process method")
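
With these base classes in place, a transform node is a few declarations plus a _process override; the base execute handles the unwrapping that is_input_list=True makes necessary, since every non-list input arrives wrapped in a one-element list. A hypothetical subclass showing the contract (not part of this patch); it would also need to be appended to DatasetExtension.get_node_list() at the bottom of the file:

```python
class InvertImagesNode(ImageProcessingNode):  # hypothetical example
    node_id = "InvertImages"
    display_name = "Invert Images"
    description = "Invert the colors of all images."

    @classmethod
    def _process(cls, images):
        # images is list[tensor]; any extra_inputs values would already be
        # unwrapped from their one-element lists by the base execute().
        return [1.0 - img for img in images]
```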


 # ========== Image Transform Nodes ==========

+
 class ResizeImagesToSameSizeNode(ImageProcessingNode):
-    DESCRIPTION = "Resize all images to the same width and height."
+    node_id = "ResizeImagesToSameSize"
+    display_name = "Resize Images to Same Size"
+    description = "Resize all images to the same width and height."
+    extra_inputs = [
+        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
+        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
+        io.Combo.Input(
+            "mode",
+            options=["stretch", "crop_center", "pad"],
+            default="stretch",
+            tooltip="Resize mode.",
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"].update({
-            "width": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Target width."}),
-            "height": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Target height."}),
-            "mode": (["stretch", "crop_center", "pad"], {"default": "stretch", "tooltip": "Resize mode."}),
-        })
-        return base_inputs
-
-    def _process(self, images, width, height, mode):
+    def _process(cls, images, width, height, mode):
         output_images = []
         for img_tensor in images:
-            img = self._tensor_to_pil(img_tensor)
+            img = tensor_to_pil(img_tensor)

             if mode == "stretch":
                 img = img.resize((width, height), Image.Resampling.LANCZOS)
@@ -329,47 +466,65 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
                 new_img.paste(img, (paste_x, paste_y))
                 img = new_img

-            output_images.append(self._pil_to_tensor(img))
+            output_images.append(pil_to_tensor(img))

         return output_images
+ node_id = "ResizeImagesToPixelCount" + display_name = "Resize Images to Pixel Count" + description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "pixel_count", + default=512 * 512, + min=1, + max=8192 * 8192, + tooltip="Target pixel count.", + ), + io.Int.Input( + "steps", + default=64, + min=1, + max=128, + tooltip="The stepping for resize width/height.", + ), + ] @classmethod - def INPUT_TYPES(cls): - base_inputs = super().INPUT_TYPES() - base_inputs["required"]["pixel_count"] = ("INT", {"default": 512 * 512, "min": 1, "max": 8192 * 8192, "step": 1, "tooltip": "Target pixel count."}) - base_inputs["required"]["steps"] = ("INT", {"default": 64, "min": 1, "max": 128, "step": 1, "tooltip": "The stepping for resize width/height."}) - return base_inputs - - def _process(self, images, pixel_count, steps): + def _process(cls, images, pixel_count, steps): output_images = [] for img_tensor in images: - img = self._tensor_to_pil(img_tensor) + img = tensor_to_pil(img_tensor) w, h = img.size pixel_count_ratio = math.sqrt(pixel_count / (w * h)) new_w = int(h * pixel_count_ratio / steps) * steps new_h = int(w * pixel_count_ratio / steps) * steps logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - output_images.append(self._pil_to_tensor(img)) + output_images.append(pil_to_tensor(img)) return output_images class ResizeImagesByShorterEdgeNode(ImageProcessingNode): - DESCRIPTION = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] @classmethod - def INPUT_TYPES(cls): - base_inputs = super().INPUT_TYPES() - base_inputs["required"]["shorter_edge"] = ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Target length for the shorter edge."}) - return base_inputs - - def _process(self, images, shorter_edge): + def _process(cls, images, shorter_edge): output_images = [] for img_tensor in images: - img = self._tensor_to_pil(img_tensor) + img = tensor_to_pil(img_tensor) w, h = img.size if w < h: new_w = shorter_edge @@ -378,23 +533,29 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode): new_h = shorter_edge new_w = int(w * (shorter_edge / h)) img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - output_images.append(self._pil_to_tensor(img)) + output_images.append(pil_to_tensor(img)) return output_images class ResizeImagesByLongerEdgeNode(ImageProcessingNode): - DESCRIPTION = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." 


 class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
-    DESCRIPTION = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
+    node_id = "ResizeImagesByLongerEdge"
+    display_name = "Resize Images by Longer Edge"
+    description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
+    extra_inputs = [
+        io.Int.Input(
+            "longer_edge",
+            default=1024,
+            min=1,
+            max=8192,
+            tooltip="Target length for the longer edge.",
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["longer_edge"] = ("INT", {"default": 1024, "min": 1, "max": 8192, "step": 1, "tooltip": "Target length for the longer edge."})
-        return base_inputs
-
-    def _process(self, images, longer_edge):
+    def _process(cls, images, longer_edge):
         output_images = []
         for img_tensor in images:
-            img = self._tensor_to_pil(img_tensor)
+            img = tensor_to_pil(img_tensor)
             w, h = img.size
             if w > h:
                 new_w = longer_edge
@@ -403,53 +564,53 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
                 new_h = longer_edge
                 new_w = int(w * (longer_edge / h))
             img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            output_images.append(self._pil_to_tensor(img))
+            output_images.append(pil_to_tensor(img))
         return output_images


 class CenterCropImagesNode(ImageProcessingNode):
-    DESCRIPTION = "Center crop all images to the specified dimensions."
+    node_id = "CenterCropImages"
+    display_name = "Center Crop Images"
+    description = "Center crop all images to the specified dimensions."
+    extra_inputs = [
+        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
+        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"].update({
-            "width": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Crop width."}),
-            "height": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Crop height."}),
-        })
-        return base_inputs
-
-    def _process(self, images, width, height):
+    def _process(cls, images, width, height):
         output_images = []
         for img_tensor in images:
-            img = self._tensor_to_pil(img_tensor)
+            img = tensor_to_pil(img_tensor)
             left = max(0, (img.width - width) // 2)
             top = max(0, (img.height - height) // 2)
             right = min(img.width, left + width)
             bottom = min(img.height, top + height)
             img = img.crop((left, top, right, bottom))
-            output_images.append(self._pil_to_tensor(img))
+            output_images.append(pil_to_tensor(img))
         return output_images


 class RandomCropImagesNode(ImageProcessingNode):
-    DESCRIPTION = "Randomly crop all images to the specified dimensions (for data augmentation)."
+    node_id = "RandomCropImages"
+    display_name = "Random Crop Images"
+    description = (
+        "Randomly crop all images to the specified dimensions (for data augmentation)."
+    )
+    extra_inputs = [
+        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
+        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
+        io.Int.Input(
+            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"].update({
-            "width": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Crop width."}),
-            "height": ("INT", {"default": 512, "min": 1, "max": 8192, "step": 1, "tooltip": "Crop height."}),
-            "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "tooltip": "Random seed."}),
-        })
-        return base_inputs
-
-    def _process(self, images, width, height, seed):
-        np.random.seed(seed%(2**32-1))
+    def _process(cls, images, width, height, seed):
+        np.random.seed(seed % (2**32 - 1))
         output_images = []
         for img_tensor in images:
-            img = self._tensor_to_pil(img_tensor)
+            img = tensor_to_pil(img_tensor)
             max_left = max(0, img.width - width)
             max_top = max(0, img.height - height)
             left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
@@ -457,219 +618,285 @@ class RandomCropImagesNode(ImageProcessingNode):
             right = min(img.width, left + width)
             bottom = min(img.height, top + height)
             img = img.crop((left, top, right, bottom))
-            output_images.append(self._pil_to_tensor(img))
+            output_images.append(pil_to_tensor(img))
         return output_images


 class FlipImagesNode(ImageProcessingNode):
-    DESCRIPTION = "Flip all images horizontally or vertically."
+    node_id = "FlipImages"
+    display_name = "Flip Images"
+    description = "Flip all images horizontally or vertically."
+    extra_inputs = [
+        io.Combo.Input(
+            "direction",
+            options=["horizontal", "vertical"],
+            default="horizontal",
+            tooltip="Flip direction.",
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["direction"] = (["horizontal", "vertical"], {"default": "horizontal", "tooltip": "Flip direction."})
-        return base_inputs
-
-    def _process(self, images, direction):
+    def _process(cls, images, direction):
         output_images = []
         for img_tensor in images:
-            img = self._tensor_to_pil(img_tensor)
+            img = tensor_to_pil(img_tensor)
             if direction == "horizontal":
                 img = img.transpose(Image.FLIP_LEFT_RIGHT)
             else:
                 img = img.transpose(Image.FLIP_TOP_BOTTOM)
-            output_images.append(self._pil_to_tensor(img))
+            output_images.append(pil_to_tensor(img))
         return output_images


 class NormalizeImagesNode(ImageProcessingNode):
-    DESCRIPTION = "Normalize images using mean and standard deviation."
+    node_id = "NormalizeImages"
+    display_name = "Normalize Images"
+    description = "Normalize images using mean and standard deviation."
+    extra_inputs = [
+        io.Float.Input(
+            "mean",
+            default=0.5,
+            min=0.0,
+            max=1.0,
+            tooltip="Mean value for normalization.",
+        ),
+        io.Float.Input(
+            "std",
+            default=0.5,
+            min=0.001,
+            max=1.0,
+            tooltip="Standard deviation for normalization.",
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"].update({
-            "mean": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Mean value for normalization."}),
-            "std": ("FLOAT", {"default": 0.5, "min": 0.001, "max": 1.0, "step": 0.01, "tooltip": "Standard deviation for normalization."}),
-        })
-        return base_inputs
-
-    def _process(self, images, mean, std):
+    def _process(cls, images, mean, std):
         return [(img - mean) / std for img in images]
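
The value-adjustment nodes here are plain per-pixel affine maps on [0, 1] tensors. With the defaults mean=0.5 and std=0.5, NormalizeImages is the usual [0, 1] to [-1, 1] rescaling, and it inverts exactly:

```python
import torch

x = torch.rand(1, 4, 4, 3)
norm = (x - 0.5) / 0.5        # [0, 1] -> [-1, 1], as in NormalizeImages
back = norm * 0.5 + 0.5       # the inverse map
assert torch.allclose(x, back)
```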


 class AdjustBrightnessNode(ImageProcessingNode):
-    DESCRIPTION = "Adjust brightness of all images."
+    node_id = "AdjustBrightness"
+    display_name = "Adjust Brightness"
+    description = "Adjust brightness of all images."
+    extra_inputs = [
+        io.Float.Input(
+            "factor",
+            default=1.0,
+            min=0.0,
+            max=2.0,
+            tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["factor"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01, "tooltip": "Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter."})
-        return base_inputs
-
-    def _process(self, images, factor):
+    def _process(cls, images, factor):
         return [(img * factor).clamp(0.0, 1.0) for img in images]


 class AdjustContrastNode(ImageProcessingNode):
-    DESCRIPTION = "Adjust contrast of all images."
+    node_id = "AdjustContrast"
+    display_name = "Adjust Contrast"
+    description = "Adjust contrast of all images."
+    extra_inputs = [
+        io.Float.Input(
+            "factor",
+            default=1.0,
+            min=0.0,
+            max=2.0,
+            tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["factor"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01, "tooltip": "Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast."})
-        return base_inputs
-
-    def _process(self, images, factor):
+    def _process(cls, images, factor):
         return [((img - 0.5) * factor + 0.5).clamp(0.0, 1.0) for img in images]


 class ShuffleDatasetNode(ImageProcessingNode):
-    DESCRIPTION = "Randomly shuffle the order of images in the dataset."
+    node_id = "ShuffleDataset"
+    display_name = "Shuffle Image Dataset"
+    description = "Randomly shuffle the order of images in the dataset."
+    extra_inputs = [
+        io.Int.Input(
+            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["seed"] = ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "tooltip": "Random seed."})
-        return base_inputs
-
-    def _process(self, images, seed):
-        np.random.seed(seed%(2**32-1))
+    def _process(cls, images, seed):
+        np.random.seed(seed % (2**32 - 1))
        indices = np.random.permutation(len(images))
        return [images[i] for i in indices]
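
RandomCropImages and the shuffle nodes reseed NumPy's global RNG with seed % (2**32 - 1), which folds the 64-bit widget value into the range np.random.seed accepts and makes runs reproducible:

```python
import numpy as np

seed = 0xFFFFFFFFFFFFFFFF           # largest value the seed widget allows
np.random.seed(seed % (2**32 - 1))  # fold into np.random.seed's valid range
a = np.random.permutation(5)
np.random.seed(seed % (2**32 - 1))
b = np.random.permutation(5)
assert (a == b).all()               # same seed -> same shuffle order
```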


-class ShuffleImageTextDatasetNode:
-    """Special node that shuffles both images and texts together (doesn't inherit from base class)."""
-
-    CATEGORY = "dataset/image"
-    EXPERIMENTAL = True
-    RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST")
-    FUNCTION = "process"
-    DESCRIPTION = "Randomly shuffle the order of images and texts in the dataset together."
+class ShuffleImageTextDatasetNode(io.ComfyNode):
+    """Special node that shuffles both images and texts together."""

     @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "images": ("IMAGE_LIST", {"tooltip": "List of images to shuffle."}),
-                "texts": ("TEXT_LIST", {"tooltip": "List of texts to shuffle."}),
-                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "tooltip": "Random seed."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ShuffleImageTextDataset",
+            display_name="Shuffle Image-Text Dataset",
+            category="dataset/image",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=[
+                io.Image.Input("images", tooltip="List of images to shuffle."),
+                io.String.Input("texts", tooltip="List of texts to shuffle."),
+                io.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    tooltip="Random seed.",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(is_output_list=True, tooltip="Shuffled images"),
+                io.String.Output(is_output_list=True, tooltip="Shuffled texts"),
+            ],
+        )

-    def process(self, images, texts, seed):
-        np.random.seed(seed%(2**32-1))
+    @classmethod
+    def execute(cls, images, texts, seed):
+        seed = seed[0]  # Extract scalar
+        np.random.seed(seed % (2**32 - 1))
         indices = np.random.permutation(len(images))
         shuffled_images = [images[i] for i in indices]
         shuffled_texts = [texts[i] for i in indices]
-        return (shuffled_images, shuffled_texts)
+        return io.NodeOutput(shuffled_images, shuffled_texts)


 # ========== Text Transform Nodes ==========

-class TextToLowercaseNode(TextProcessingNode):
-    DESCRIPTION = "Convert all texts to lowercase."
-
-    def _process(self, texts):
+
+class TextToLowercaseNode(TextProcessingNode):
+    node_id = "TextToLowercase"
+    display_name = "Text to Lowercase"
+    description = "Convert all texts to lowercase."
+
+    @classmethod
+    def _process(cls, texts):
         return [text.lower() for text in texts]


 class TextToUppercaseNode(TextProcessingNode):
-    DESCRIPTION = "Convert all texts to uppercase."
+    node_id = "TextToUppercase"
+    display_name = "Text to Uppercase"
+    description = "Convert all texts to uppercase."

-    def _process(self, texts):
+    @classmethod
+    def _process(cls, texts):
         return [text.upper() for text in texts]


 class TruncateTextNode(TextProcessingNode):
-    DESCRIPTION = "Truncate all texts to a maximum length."
+    node_id = "TruncateText"
+    display_name = "Truncate Text"
+    description = "Truncate all texts to a maximum length."
+    extra_inputs = [
+        io.Int.Input(
+            "max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
+        ),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["max_length"] = ("INT", {"default": 77, "min": 1, "max": 10000, "step": 1, "tooltip": "Maximum text length."})
-        return base_inputs
-
-    def _process(self, texts, max_length):
+    def _process(cls, texts, max_length):
         return [text[:max_length] for text in texts]
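
ShuffleImageTextDataset above stays a standalone io.ComfyNode because it must draw a single permutation and index both lists with it; that shared index list is what keeps each caption attached to its image:

```python
import numpy as np

images = ["img0", "img1", "img2"]  # stand-ins for image tensors
texts = ["cap0", "cap1", "cap2"]

np.random.seed(0)
idx = np.random.permutation(len(images))
images = [images[i] for i in idx]
texts = [texts[i] for i in idx]
assert all(i[-1] == t[-1] for i, t in zip(images, texts))  # pairs stay aligned
```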


 class AddTextPrefixNode(TextProcessingNode):
-    DESCRIPTION = "Add a prefix to all texts."
+    node_id = "AddTextPrefix"
+    display_name = "Add Text Prefix"
+    description = "Add a prefix to all texts."
+    extra_inputs = [
+        io.String.Input("prefix", default="", tooltip="Prefix to add."),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["prefix"] = ("STRING", {"default": "", "multiline": False, "tooltip": "Prefix to add."})
-        return base_inputs
-
-    def _process(self, texts, prefix):
+    def _process(cls, texts, prefix):
         return [prefix + text for text in texts]


 class AddTextSuffixNode(TextProcessingNode):
-    DESCRIPTION = "Add a suffix to all texts."
+    node_id = "AddTextSuffix"
+    display_name = "Add Text Suffix"
+    description = "Add a suffix to all texts."
+    extra_inputs = [
+        io.String.Input("suffix", default="", tooltip="Suffix to add."),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"]["suffix"] = ("STRING", {"default": "", "multiline": False, "tooltip": "Suffix to add."})
-        return base_inputs
-
-    def _process(self, texts, suffix):
+    def _process(cls, texts, suffix):
         return [text + suffix for text in texts]


 class ReplaceTextNode(TextProcessingNode):
-    DESCRIPTION = "Replace text in all texts."
+    node_id = "ReplaceText"
+    display_name = "Replace Text"
+    description = "Replace text in all texts."
+    extra_inputs = [
+        io.String.Input("find", default="", tooltip="Text to find."),
+        io.String.Input("replace", default="", tooltip="Text to replace with."),
+    ]

     @classmethod
-    def INPUT_TYPES(cls):
-        base_inputs = super().INPUT_TYPES()
-        base_inputs["required"].update({
-            "find": ("STRING", {"default": "", "multiline": False, "tooltip": "Text to find."}),
-            "replace": ("STRING", {"default": "", "multiline": False, "tooltip": "Text to replace with."}),
-        })
-        return base_inputs
-
-    def _process(self, texts, find, replace):
+    def _process(cls, texts, find, replace):
         return [text.replace(find, replace) for text in texts]


 class StripWhitespaceNode(TextProcessingNode):
-    DESCRIPTION = "Strip leading and trailing whitespace from all texts."
+    node_id = "StripWhitespace"
+    display_name = "Strip Whitespace"
+    description = "Strip leading and trailing whitespace from all texts."

-    def _process(self, texts):
+    @classmethod
+    def _process(cls, texts):
         return [text.strip() for text in texts]


 # ========== Training Dataset Nodes ==========

-class MakeTrainingDataset:
+
+class MakeTrainingDataset(io.ComfyNode):
     """Encode images with VAE and texts with CLIP to create a training dataset."""

     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "images": ("IMAGE_LIST", {"tooltip": "List of images to encode."}),
-                "vae": ("VAE", {"tooltip": "VAE model for encoding images to latents."}),
-                "clip": ("CLIP", {"tooltip": "CLIP model for encoding text to conditioning."}),
-            },
-            "optional": {
-                "texts": ("TEXT_LIST", {"tooltip": "List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string)."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="MakeTrainingDataset",
+            display_name="Make Training Dataset",
+            category="dataset",
+            is_experimental=True,
+            is_input_list=True,  # images and texts as lists
+            inputs=[
+                io.Image.Input("images", tooltip="List of images to encode."),
+                io.Vae.Input(
+                    "vae", tooltip="VAE model for encoding images to latents."
+                ),
+                io.Clip.Input(
+                    "clip", tooltip="CLIP model for encoding text to conditioning."
+                ),
+                io.String.Input(
+                    "texts",
+                    optional=True,
+                    tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
+                ),
+            ],
+            outputs=[
+                io.Latent.Output(is_output_list=True, tooltip="List of latent dicts"),
+                io.Conditioning.Output(
+                    is_output_list=True, tooltip="List of conditioning lists"
+                ),
+            ],
+        )

-    RETURN_TYPES = ("LATENT", "CONDITIONING")
-    RETURN_NAMES = ("latents", "conditioning")
-    FUNCTION = "make_dataset"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Encodes images with VAE and texts with CLIP to create a training dataset. Returns a list of latents and a flat conditioning list."
+    @classmethod
+    def execute(cls, images, vae, clip, texts=None):
+        # Extract scalars (vae and clip are single values wrapped in lists)
+        vae = vae[0]
+        clip = clip[0]

-    def make_dataset(self, images, vae, clip, texts=None):
         # Handle text list
         num_images = len(images)

@@ -688,53 +915,77 @@ class MakeTrainingDataset:

         # Encode images with VAE
         logging.info(f"Encoding {num_images} images with VAE...")
-        latents = []
+        latents_list = []  # list[{"samples": tensor}]
         for img_tensor in images:
             # img_tensor is [1, H, W, 3]
-            t = vae.encode(img_tensor[:,:,:,:3])
-            latents.append(t)
-        latents = {"samples": latents}
+            latent_tensor = vae.encode(img_tensor[:, :, :, :3])
+            latents_list.append({"samples": latent_tensor})

         # Encode texts with CLIP
         logging.info(f"Encoding {len(texts)} texts with CLIP...")
-        conditions = []
-        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
+        conditioning_list = []  # list[list[cond]]
         for text in texts:
             if text == "":
-                conditions.extend(empty_cond)
+                cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
             else:
                 tokens = clip.tokenize(text)
                 cond = clip.encode_from_tokens_scheduled(tokens)
-            conditions.extend(cond)
+            conditioning_list.append(cond)

-        logging.info(f"Created dataset with {len(latents['samples'])} latents and {len(conditions)} conditions.")
-        return (latents, conditions)
+        logging.info(
+            f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
+        )
+        return io.NodeOutput(latents_list, conditioning_list)
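
Per the loops above, MakeTrainingDataset now emits two parallel lists, one {"samples": tensor} dict per image and one conditioning list per caption, instead of a single batched latent dict plus a flat conditioning list. A consumer can rely on the pairing:

```python
def iter_samples(latents, conditioning):
    # latents:      [{"samples": tensor}, ...]  one entry per image
    # conditioning: [[cond, ...], ...]          one entry per caption
    assert len(latents) == len(conditioning)
    for latent, cond in zip(latents, conditioning):
        yield latent["samples"], cond
```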


-class SaveTrainingDataset:
+class SaveTrainingDataset(io.ComfyNode):
     """Save encoded training dataset (latents + conditioning) to disk."""

     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "latents": ("LATENT", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
-                "conditioning": ("CONDITIONING", {"tooltip": "Conditioning list from MakeTrainingDataset."}),
-                "folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder to save dataset (inside output directory)."}),
-                "shard_size": ("INT", {"default": 1000, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of samples per shard file."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SaveTrainingDataset",
+            display_name="Save Training Dataset",
+            category="dataset",
+            is_experimental=True,
+            is_output_node=True,
+            is_input_list=True,  # Receive lists
+            inputs=[
+                io.Latent.Input(
+                    "latents",
+                    tooltip="List of latent dicts from MakeTrainingDataset.",
+                ),
+                io.Conditioning.Input(
+                    "conditioning",
+                    tooltip="List of conditioning lists from MakeTrainingDataset.",
+                ),
+                io.String.Input(
+                    "folder_name",
+                    default="training_dataset",
+                    tooltip="Name of folder to save dataset (inside output directory).",
+                ),
+                io.Int.Input(
+                    "shard_size",
+                    default=1000,
+                    min=1,
+                    max=100000,
+                    tooltip="Number of samples per shard file.",
+                ),
+            ],
+            outputs=[],
+        )

-    RETURN_TYPES = ()
-    OUTPUT_NODE = True
-    FUNCTION = "save_dataset"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Saves a training dataset to disk in sharded pickle files. Each shard contains (latent, conditioning) pairs."
+    @classmethod
+    def execute(cls, latents, conditioning, folder_name, shard_size):
+        # Extract scalars
+        folder_name = folder_name[0]
+        shard_size = shard_size[0]
+
+        # latents: list[{"samples": tensor}]
+        # conditioning: list[list[cond]]

-    def save_dataset(self, latents, conditioning, folder_name, shard_size):
         # Validate lengths match
-        if len(latents["samples"]) != len(conditioning):
+        if len(latents) != len(conditioning):
             raise ValueError(
                 f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
                 f"Something went wrong in dataset preparation."
             )
@@ -745,19 +996,21 @@ class SaveTrainingDataset:
         os.makedirs(output_dir, exist_ok=True)

         # Prepare data pairs
-        num_samples = len(latents["samples"])
+        num_samples = len(latents)
         num_shards = (num_samples + shard_size - 1) // shard_size  # Ceiling division

-        logging.info(f"Saving {num_samples} samples to {num_shards} shards in {output_dir}...")
+        logging.info(
+            f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
+        )

         # Save data in shards
         for shard_idx in range(num_shards):
             start_idx = shard_idx * shard_size
             end_idx = min(start_idx + shard_size, num_samples)

-            # Get shard data
+            # Get shard data (list of latent dicts and conditioning lists)
             shard_data = {
-                "latents": latents["samples"][start_idx:end_idx],
+                "latents": latents[start_idx:end_idx],
                 "conditioning": conditioning[start_idx:end_idx],
             }

             # Save shard
@@ -768,7 +1021,9 @@ class SaveTrainingDataset:
             with open(shard_path, "wb") as f:
                 pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)

-            logging.info(f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)")
+            logging.info(
+                f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
+            )

         # Save metadata
         metadata = {
@@ -779,31 +1034,40 @@ class SaveTrainingDataset:
         metadata_path = os.path.join(output_dir, "metadata.json")
         with open(metadata_path, "w") as f:
             import json
+
             json.dump(metadata, f, indent=2)

         logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
-        return {}
+        return io.NodeOutput()
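
Each shard written above is a pickled dict holding aligned "latents" and "conditioning" slices. Reading one back mirrors what LoadTrainingDataset does below; the directory path here is hypothetical, and since pickle executes code on load, only open shards you wrote yourself:

```python
import os
import pickle

dataset_dir = os.path.join("output", "training_dataset")  # hypothetical path
for name in sorted(os.listdir(dataset_dir)):
    if name.startswith("shard_") and name.endswith(".pkl"):  # the loader's filter
        with open(os.path.join(dataset_dir, name), "rb") as f:
            shard = pickle.load(f)
        print(name, len(shard["latents"]), len(shard["conditioning"]))
```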


-class LoadTrainingDataset:
+class LoadTrainingDataset(io.ComfyNode):
     """Load encoded training dataset from disk."""

     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder containing the saved dataset (inside output directory)."}),
-            },
-        }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoadTrainingDataset",
+            display_name="Load Training Dataset",
+            category="dataset",
+            is_experimental=True,
+            inputs=[
+                io.String.Input(
+                    "folder_name",
+                    default="training_dataset",
+                    tooltip="Name of folder containing the saved dataset (inside output directory).",
+                ),
+            ],
+            outputs=[
+                io.Latent.Output(is_output_list=True, tooltip="List of latent dicts"),
+                io.Conditioning.Output(
+                    is_output_list=True, tooltip="List of conditioning lists"
+                ),
+            ],
+        )

-    RETURN_TYPES = ("LATENT", "CONDITIONING")
-    RETURN_NAMES = ("latents", "conditioning")
-    FUNCTION = "load_dataset"
-    CATEGORY = "dataset"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a training dataset from disk. Returns a list of latents and a flat conditioning list."
-
-    def load_dataset(self, folder_name):
+    @classmethod
+    def execute(cls, folder_name):
         # Get dataset directory
         dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)

@@ -811,10 +1075,13 @@ class LoadTrainingDataset:
             raise ValueError(f"Dataset directory not found: {dataset_dir}")

         # Find all shard files
-        shard_files = sorted([
-            f for f in os.listdir(dataset_dir)
-            if f.startswith("shard_") and f.endswith(".pkl")
-        ])
+        shard_files = sorted(
+            [
+                f
+                for f in os.listdir(dataset_dir)
+                if f.startswith("shard_") and f.endswith(".pkl")
+            ]
+        )

         if not shard_files:
             raise ValueError(f"No shard files found in {dataset_dir}")
@@ -822,8 +1089,8 @@ class LoadTrainingDataset:
         logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")

         # Load all shards
-        all_latents = []
-        all_conditioning = []
+        all_latents = []  # list[{"samples": tensor}]
+        all_conditioning = []  # list[list[cond]]

         for shard_file in shard_files:
             shard_path = os.path.join(dataset_dir, shard_file)
@@ -836,70 +1103,51 @@ class LoadTrainingDataset:

             logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")

-        logging.info(f"Successfully loaded {len(all_latents)} samples from {dataset_dir}.")
-        return ({"samples": all_latents}, all_conditioning)
+        logging.info(
+            f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
+        )
+        return io.NodeOutput(all_latents, all_conditioning)
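
Shards are concatenated in sorted() filename order, so sample order survives a save/load round trip only if shard names sort lexicographically. The exact filename format string is elided from this diff; zero-padded indices (illustrated below with made-up names) are what would guarantee the ordering:

```python
names = ["shard_000010.pkl", "shard_000002.pkl"]  # illustrative, zero-padded
assert sorted(names) == ["shard_000002.pkl", "shard_000010.pkl"]  # numeric order kept

bad = ["shard_10.pkl", "shard_2.pkl"]  # unpadded indices would sort wrong
assert sorted(bad) == ["shard_10.pkl", "shard_2.pkl"]  # shard 10 before shard 2
```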


-NODE_CLASS_MAPPINGS = {
-    "LoadImageDataSetFromFolderNode": LoadImageDataSetFromFolderNode,
-    "LoadImageTextDataSetFromFolderNode": LoadImageTextDataSetFromFolderNode,
-    "SaveImageDataSetToFolderNode": SaveImageDataSetToFolderNode,
-    "SaveImageTextDataSetToFolderNode": SaveImageTextDataSetToFolderNode,
-    # Image transforms
-    "ResizeImagesToSameSizeNode": ResizeImagesToSameSizeNode,
-    "ResizeImagesToPixelCountNode": ResizeImagesToPixelCountNode,
-    "ResizeImagesByShorterEdgeNode": ResizeImagesByShorterEdgeNode,
-    "ResizeImagesByLongerEdgeNode": ResizeImagesByLongerEdgeNode,
-    "CenterCropImagesNode": CenterCropImagesNode,
-    "RandomCropImagesNode": RandomCropImagesNode,
-    "FlipImagesNode": FlipImagesNode,
-    "NormalizeImagesNode": NormalizeImagesNode,
-    "AdjustBrightnessNode": AdjustBrightnessNode,
-    "AdjustContrastNode": AdjustContrastNode,
-    "ShuffleDatasetNode": ShuffleDatasetNode,
-    "ShuffleImageTextDatasetNode": ShuffleImageTextDatasetNode,
-    # Text transforms
-    "TextToLowercaseNode": TextToLowercaseNode,
-    "TextToUppercaseNode": TextToUppercaseNode,
-    "TruncateTextNode": TruncateTextNode,
-    "AddTextPrefixNode": AddTextPrefixNode,
-    "AddTextSuffixNode": AddTextSuffixNode,
-    "ReplaceTextNode": ReplaceTextNode,
-    "StripWhitespaceNode": StripWhitespaceNode,
-    # Training dataset nodes
-    "MakeTrainingDataset": MakeTrainingDataset,
-    "SaveTrainingDataset": SaveTrainingDataset,
-    "LoadTrainingDataset": LoadTrainingDataset,
-}
+# ========== Extension Setup ==========

-NODE_DISPLAY_NAME_MAPPINGS = {
-    "LoadImageDataSetFromFolderNode": "Load Simple Image Dataset from Folder",
-    "LoadImageTextDataSetFromFolderNode": "Load Simple Image and Text Dataset from Folder",
-    "SaveImageDataSetToFolderNode": "Save Simple Image Dataset to Folder",
-    "SaveImageTextDataSetToFolderNode": "Save Simple Image and Text Dataset to Folder",
-    # Image transforms
-    "ResizeImagesToSameSizeNode": "Resize Images to Same Size",
-    "ResizeImagesToPixelCountNode": "Resize Images to Pixel Count",
-    "ResizeImagesByShorterEdgeNode": "Resize Images by Shorter Edge",
-    "ResizeImagesByLongerEdgeNode": "Resize Images by Longer Edge",
"CenterCropImagesNode": "Center Crop Images", - "RandomCropImagesNode": "Random Crop Images", - "FlipImagesNode": "Flip Images", - "NormalizeImagesNode": "Normalize Images", - "AdjustBrightnessNode": "Adjust Brightness", - "AdjustContrastNode": "Adjust Contrast", - "ShuffleDatasetNode": "Shuffle Image Dataset", - "ShuffleImageTextDatasetNode": "Shuffle Image-Text Dataset", - # Text transforms - "TextToLowercaseNode": "Text to Lowercase", - "TextToUppercaseNode": "Text to Uppercase", - "TruncateTextNode": "Truncate Text", - "AddTextPrefixNode": "Add Text Prefix", - "AddTextSuffixNode": "Add Text Suffix", - "ReplaceTextNode": "Replace Text", - "StripWhitespaceNode": "Strip Whitespace", - # Training dataset nodes - "MakeTrainingDataset": "Make Training Dataset", - "SaveTrainingDataset": "Save Training Dataset", - "LoadTrainingDataset": "Load Training Dataset", -} + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesToSameSizeNode, + ResizeImagesToPixelCountNode, + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + FlipImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension()