diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 59b5254e2..38b65bfab 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -309,29 +309,71 @@ def pil_to_tensor(img): class ImageProcessingNode(io.ComfyNode): - """Base class for image processing nodes that operate on lists of images. + """Base class for image processing nodes that operate on images. Child classes should set: node_id: Unique node identifier (required) display_name: Display name (optional, defaults to node_id) description: Node description (optional) extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) - Child classes must implement: - _process(cls, images, **kwargs) -> list[tensor] + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) """ node_id = None display_name = None description = None extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = True # Configurable output mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden + base_class = ImageProcessingNode + has_process = cls._process is not base_class._process + has_group = cls._group_process is not base_class._group_process + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group @classmethod def define_schema(cls): if cls.node_id is None: raise NotImplementedError(f"{cls.__name__} must set node_id class variable") - inputs = [io.Image.Input("images", tooltip="List of images to process.")] + is_group = cls._detect_processing_mode() + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] inputs.extend(cls.extra_inputs) return io.Schema( @@ -339,12 +381,12 @@ class ImageProcessingNode(io.ComfyNode): display_name=cls.display_name or cls.node_id, category="dataset/image", is_experimental=True, - is_input_list=True, + is_input_list=is_group, # True for group, False for individual inputs=inputs, outputs=[ io.Image.Output( display_name="images", - is_output_list=True, + is_output_list=cls.is_output_list, tooltip="Processed images", ) ], @@ -352,8 +394,10 @@ class ImageProcessingNode(io.ComfyNode): @classmethod def execute(cls, images, **kwargs): - """Execute the node. Extracts scalar values and calls _process.""" - # Extract scalar values from lists (due to is_input_list=True) + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters params = {} for k, v in kwargs.items(): if isinstance(v, list) and len(v) == 1: @@ -361,12 +405,37 @@ class ImageProcessingNode(io.ComfyNode): else: params[k] = v - result = cls._process(images, **params) - return io.NodeOutput(result) + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) @classmethod - def _process(cls, images, **kwargs): - """Override this method in subclasses to implement specific processing. + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. Args: images: list[tensor] - List of image tensors @@ -375,33 +444,75 @@ class ImageProcessingNode(io.ComfyNode): Returns: list[tensor] - Processed images """ - raise NotImplementedError(f"{cls.__name__} must implement _process method") + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) class TextProcessingNode(io.ComfyNode): - """Base class for text processing nodes that operate on lists of texts. + """Base class for text processing nodes that operate on texts. Child classes should set: node_id: Unique node identifier (required) display_name: Display name (optional, defaults to node_id) description: Node description (optional) extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) - Child classes must implement: - _process(cls, texts, **kwargs) -> list[str] + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) """ node_id = None display_name = None description = None extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = True # Configurable output mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden + base_class = TextProcessingNode + has_process = cls._process is not base_class._process + has_group = cls._group_process is not base_class._group_process + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group @classmethod def define_schema(cls): if cls.node_id is None: raise NotImplementedError(f"{cls.__name__} must set node_id class variable") - inputs = [io.String.Input("texts", tooltip="List of texts to process.")] + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] inputs.extend(cls.extra_inputs) return io.Schema( @@ -409,19 +520,23 @@ class TextProcessingNode(io.ComfyNode): display_name=cls.display_name or cls.node_id, category="dataset/text", is_experimental=True, - is_input_list=True, + is_input_list=is_group, # True for group, False for individual inputs=inputs, outputs=[ io.String.Output( - display_name="texts", is_output_list=True, tooltip="Processed texts" + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", ) ], ) @classmethod def execute(cls, texts, **kwargs): - """Execute the node. Extracts scalar values and calls _process.""" - # Extract scalar values from lists (due to is_input_list=True) + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters params = {} for k, v in kwargs.items(): if isinstance(v, list) and len(v) == 1: @@ -429,12 +544,37 @@ class TextProcessingNode(io.ComfyNode): else: params[k] = v - result = cls._process(texts, **params) - return io.NodeOutput(result) + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) @classmethod - def _process(cls, texts, **kwargs): - """Override this method in subclasses to implement specific processing. + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. Args: texts: list[str] - List of text strings @@ -443,7 +583,9 @@ class TextProcessingNode(io.ComfyNode): Returns: list[str] - Processed texts """ - raise NotImplementedError(f"{cls.__name__} must implement _process method") + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) # ========== Image Transform Nodes ========== @@ -465,31 +607,28 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, width, height, mode): - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) + def _process(cls, image, width, height, mode): + img = tensor_to_pil(image) - if mode == "stretch": + if mode == "stretch": + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "crop_center": + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + if img.width != width or img.height != height: img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "crop_center": - left = max(0, (img.width - width) // 2) - top = max(0, (img.height - height) // 2) - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - if img.width != width or img.height != height: - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "pad": - img.thumbnail((width, height), Image.Resampling.LANCZOS) - new_img = Image.new("RGB", (width, height), (0, 0, 0)) - paste_x = (width - img.width) // 2 - paste_y = (height - img.height) // 2 - new_img.paste(img, (paste_x, paste_y)) - img = new_img + elif mode == "pad": + img.thumbnail((width, height), Image.Resampling.LANCZOS) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + paste_x = (width - img.width) // 2 + paste_y = (height - img.height) // 2 + new_img.paste(img, (paste_x, paste_y)) + img = new_img - output_images.append(pil_to_tensor(img)) - return output_images + return pil_to_tensor(img) class ResizeImagesToPixelCountNode(ImageProcessingNode): @@ -514,18 +653,15 @@ class ResizeImagesToPixelCountNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, pixel_count, steps): - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) - w, h = img.size - pixel_count_ratio = math.sqrt(pixel_count / (w * h)) - new_w = int(h * pixel_count_ratio / steps) * steps - new_h = int(w * pixel_count_ratio / steps) * steps - logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - output_images.append(pil_to_tensor(img)) - return output_images + def _process(cls, image, pixel_count, steps): + img = tensor_to_pil(image) + w, h = img.size + pixel_count_ratio = math.sqrt(pixel_count / (w * h)) + new_w = int(w * pixel_count_ratio / steps) * steps + new_h = int(h * pixel_count_ratio / steps) * steps + logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) class ResizeImagesByShorterEdgeNode(ImageProcessingNode): @@ -543,20 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, shorter_edge): - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) - w, h = img.size - if w < h: - new_w = shorter_edge - new_h = int(h * (shorter_edge / w)) - else: - new_h = shorter_edge - new_w = int(w * (shorter_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - output_images.append(pil_to_tensor(img)) - return output_images + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) class ResizeImagesByLongerEdgeNode(ImageProcessingNode): @@ -574,20 +707,17 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, longer_edge): - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) - w, h = img.size - if w > h: - new_w = longer_edge - new_h = int(h * (longer_edge / w)) - else: - new_h = longer_edge - new_w = int(w * (longer_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - output_images.append(pil_to_tensor(img)) - return output_images + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) class CenterCropImagesNode(ImageProcessingNode): @@ -600,17 +730,14 @@ class CenterCropImagesNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, width, height): - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) - left = max(0, (img.width - width) // 2) - top = max(0, (img.height - height) // 2) - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - output_images.append(pil_to_tensor(img)) - return output_images + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) class RandomCropImagesNode(ImageProcessingNode): @@ -628,20 +755,17 @@ class RandomCropImagesNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, width, height, seed): + def _process(cls, image, width, height, seed): np.random.seed(seed % (2**32 - 1)) - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) - max_left = max(0, img.width - width) - max_top = max(0, img.height - height) - left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 - top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - output_images.append(pil_to_tensor(img)) - return output_images + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) class FlipImagesNode(ImageProcessingNode): @@ -658,16 +782,13 @@ class FlipImagesNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, direction): - output_images = [] - for img_tensor in images: - img = tensor_to_pil(img_tensor) - if direction == "horizontal": - img = img.transpose(Image.FLIP_LEFT_RIGHT) - else: - img = img.transpose(Image.FLIP_TOP_BOTTOM) - output_images.append(pil_to_tensor(img)) - return output_images + def _process(cls, image, direction): + img = tensor_to_pil(image) + if direction == "horizontal": + img = img.transpose(Image.FLIP_LEFT_RIGHT) + else: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + return pil_to_tensor(img) class NormalizeImagesNode(ImageProcessingNode): @@ -692,8 +813,8 @@ class NormalizeImagesNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, mean, std): - return [(img - mean) / std for img in images] + def _process(cls, image, mean, std): + return (image - mean) / std class AdjustBrightnessNode(ImageProcessingNode): @@ -711,8 +832,8 @@ class AdjustBrightnessNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, factor): - return [(img * factor).clamp(0.0, 1.0) for img in images] + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) class AdjustContrastNode(ImageProcessingNode): @@ -730,14 +851,15 @@ class AdjustContrastNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, factor): - return [((img - 0.5) * factor + 0.5).clamp(0.0, 1.0) for img in images] + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) class ShuffleDatasetNode(ImageProcessingNode): node_id = "ShuffleDataset" display_name = "Shuffle Image Dataset" description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle extra_inputs = [ io.Int.Input( "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." @@ -745,7 +867,7 @@ class ShuffleDatasetNode(ImageProcessingNode): ] @classmethod - def _process(cls, images, seed): + def _group_process(cls, images, seed): np.random.seed(seed % (2**32 - 1)) indices = np.random.permutation(len(images)) return [images[i] for i in indices] @@ -804,8 +926,8 @@ class TextToLowercaseNode(TextProcessingNode): description = "Convert all texts to lowercase." @classmethod - def _process(cls, texts): - return [text.lower() for text in texts] + def _process(cls, text): + return text.lower() class TextToUppercaseNode(TextProcessingNode): @@ -814,8 +936,8 @@ class TextToUppercaseNode(TextProcessingNode): description = "Convert all texts to uppercase." @classmethod - def _process(cls, texts): - return [text.upper() for text in texts] + def _process(cls, text): + return text.upper() class TruncateTextNode(TextProcessingNode): @@ -829,8 +951,8 @@ class TruncateTextNode(TextProcessingNode): ] @classmethod - def _process(cls, texts, max_length): - return [text[:max_length] for text in texts] + def _process(cls, text, max_length): + return text[:max_length] class AddTextPrefixNode(TextProcessingNode): @@ -842,8 +964,8 @@ class AddTextPrefixNode(TextProcessingNode): ] @classmethod - def _process(cls, texts, prefix): - return [prefix + text for text in texts] + def _process(cls, text, prefix): + return prefix + text class AddTextSuffixNode(TextProcessingNode): @@ -855,8 +977,8 @@ class AddTextSuffixNode(TextProcessingNode): ] @classmethod - def _process(cls, texts, suffix): - return [text + suffix for text in texts] + def _process(cls, text, suffix): + return text + suffix class ReplaceTextNode(TextProcessingNode): @@ -869,8 +991,8 @@ class ReplaceTextNode(TextProcessingNode): ] @classmethod - def _process(cls, texts, find, replace): - return [text.replace(find, replace) for text in texts] + def _process(cls, text, find, replace): + return text.replace(find, replace) class StripWhitespaceNode(TextProcessingNode): @@ -879,8 +1001,189 @@ class StripWhitespaceNode(TextProcessingNode): description = "Strip leading and trailing whitespace from all texts." @classmethod - def _process(cls, texts): - return [text.strip() for text in texts] + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts # ========== Training Dataset Nodes ========== @@ -1182,6 +1485,11 @@ class DatasetExtension(ComfyExtension): AddTextSuffixNode, ReplaceTextNode, StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, # Training dataset nodes MakeTrainingDataset, SaveTrainingDataset,