use single process instead of input list when no need

This commit is contained in:
Kohaku-Blueleaf 2025-11-18 20:19:18 +08:00
parent 9217d54a8c
commit 0318d2d60c

View File

@ -309,29 +309,71 @@ def pil_to_tensor(img):
class ImageProcessingNode(io.ComfyNode):
"""Base class for image processing nodes that operate on lists of images.
"""Base class for image processing nodes that operate on images.
Child classes should set:
node_id: Unique node identifier (required)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "images" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
Child classes must implement:
_process(cls, images, **kwargs) -> list[tensor]
Child classes must implement ONE of:
_process(cls, image, **kwargs) -> tensor (for single-item processing)
_group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
"""
node_id = None
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = True # Configurable output mode
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
Returns:
bool: True if group processing, False if individual processing
"""
# Explicit setting takes precedence
if cls.is_group_process is not None:
return cls.is_group_process
# Check which method is overridden
base_class = ImageProcessingNode
has_process = cls._process is not base_class._process
has_group = cls._group_process is not base_class._group_process
if has_process and has_group:
raise ValueError(
f"{cls.__name__}: Cannot override both _process and _group_process. "
"Override only one, or set is_group_process explicitly."
)
if not has_process and not has_group:
raise ValueError(
f"{cls.__name__}: Must override either _process or _group_process"
)
return has_group
@classmethod
def define_schema(cls):
if cls.node_id is None:
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
inputs = [io.Image.Input("images", tooltip="List of images to process.")]
is_group = cls._detect_processing_mode()
inputs = [
io.Image.Input(
"images",
tooltip=(
"List of images to process." if is_group else "Image to process."
),
)
]
inputs.extend(cls.extra_inputs)
return io.Schema(
@ -339,12 +381,12 @@ class ImageProcessingNode(io.ComfyNode):
display_name=cls.display_name or cls.node_id,
category="dataset/image",
is_experimental=True,
is_input_list=True,
is_input_list=is_group, # True for group, False for individual
inputs=inputs,
outputs=[
io.Image.Output(
display_name="images",
is_output_list=True,
is_output_list=cls.is_output_list,
tooltip="Processed images",
)
],
@ -352,8 +394,10 @@ class ImageProcessingNode(io.ComfyNode):
@classmethod
def execute(cls, images, **kwargs):
"""Execute the node. Extracts scalar values and calls _process."""
# Extract scalar values from lists (due to is_input_list=True)
"""Execute the node. Routes to _process or _group_process based on mode."""
is_group = cls._detect_processing_mode()
# Extract scalar values from lists for parameters
params = {}
for k, v in kwargs.items():
if isinstance(v, list) and len(v) == 1:
@ -361,12 +405,37 @@ class ImageProcessingNode(io.ComfyNode):
else:
params[k] = v
result = cls._process(images, **params)
return io.NodeOutput(result)
if is_group:
# Group processing: images is list, call _group_process
result = cls._group_process(images, **params)
else:
# Individual processing: images is single item, call _process
result = cls._process(images, **params)
# Wrap result based on is_output_list
if cls.is_output_list:
# Result should already be a list (or will be for individual)
return io.NodeOutput(result if is_group else [result])
else:
# Single output - wrap in list for NodeOutput
return io.NodeOutput([result])
@classmethod
def _process(cls, images, **kwargs):
"""Override this method in subclasses to implement specific processing.
def _process(cls, image, **kwargs):
"""Override this method for single-item processing.
Args:
image: tensor - Single image tensor
**kwargs: Additional parameters (already extracted from lists)
Returns:
tensor - Processed image
"""
raise NotImplementedError(f"{cls.__name__} must implement _process method")
@classmethod
def _group_process(cls, images, **kwargs):
"""Override this method for group processing.
Args:
images: list[tensor] - List of image tensors
@ -375,33 +444,75 @@ class ImageProcessingNode(io.ComfyNode):
Returns:
list[tensor] - Processed images
"""
raise NotImplementedError(f"{cls.__name__} must implement _process method")
raise NotImplementedError(
f"{cls.__name__} must implement _group_process method"
)
class TextProcessingNode(io.ComfyNode):
"""Base class for text processing nodes that operate on lists of texts.
"""Base class for text processing nodes that operate on texts.
Child classes should set:
node_id: Unique node identifier (required)
display_name: Display name (optional, defaults to node_id)
description: Node description (optional)
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
is_output_list: True (list output) or False (single output) (optional, default True)
Child classes must implement:
_process(cls, texts, **kwargs) -> list[str]
Child classes must implement ONE of:
_process(cls, text, **kwargs) -> str (for single-item processing)
_group_process(cls, texts, **kwargs) -> list[str] (for group processing)
"""
node_id = None
display_name = None
description = None
extra_inputs = []
is_group_process = None # None = auto-detect, True/False = explicit
is_output_list = True # Configurable output mode
@classmethod
def _detect_processing_mode(cls):
"""Detect whether this node uses group or individual processing.
Returns:
bool: True if group processing, False if individual processing
"""
# Explicit setting takes precedence
if cls.is_group_process is not None:
return cls.is_group_process
# Check which method is overridden
base_class = TextProcessingNode
has_process = cls._process is not base_class._process
has_group = cls._group_process is not base_class._group_process
if has_process and has_group:
raise ValueError(
f"{cls.__name__}: Cannot override both _process and _group_process. "
"Override only one, or set is_group_process explicitly."
)
if not has_process and not has_group:
raise ValueError(
f"{cls.__name__}: Must override either _process or _group_process"
)
return has_group
@classmethod
def define_schema(cls):
if cls.node_id is None:
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
inputs = [io.String.Input("texts", tooltip="List of texts to process.")]
is_group = cls._detect_processing_mode()
inputs = [
io.String.Input(
"texts",
tooltip="List of texts to process." if is_group else "Text to process.",
)
]
inputs.extend(cls.extra_inputs)
return io.Schema(
@ -409,19 +520,23 @@ class TextProcessingNode(io.ComfyNode):
display_name=cls.display_name or cls.node_id,
category="dataset/text",
is_experimental=True,
is_input_list=True,
is_input_list=is_group, # True for group, False for individual
inputs=inputs,
outputs=[
io.String.Output(
display_name="texts", is_output_list=True, tooltip="Processed texts"
display_name="texts",
is_output_list=cls.is_output_list,
tooltip="Processed texts",
)
],
)
@classmethod
def execute(cls, texts, **kwargs):
"""Execute the node. Extracts scalar values and calls _process."""
# Extract scalar values from lists (due to is_input_list=True)
"""Execute the node. Routes to _process or _group_process based on mode."""
is_group = cls._detect_processing_mode()
# Extract scalar values from lists for parameters
params = {}
for k, v in kwargs.items():
if isinstance(v, list) and len(v) == 1:
@ -429,12 +544,37 @@ class TextProcessingNode(io.ComfyNode):
else:
params[k] = v
result = cls._process(texts, **params)
return io.NodeOutput(result)
if is_group:
# Group processing: texts is list, call _group_process
result = cls._group_process(texts, **params)
else:
# Individual processing: texts is single item, call _process
result = cls._process(texts, **params)
# Wrap result based on is_output_list
if cls.is_output_list:
# Result should already be a list (or will be for individual)
return io.NodeOutput(result if is_group else [result])
else:
# Single output - wrap in list for NodeOutput
return io.NodeOutput([result])
@classmethod
def _process(cls, texts, **kwargs):
"""Override this method in subclasses to implement specific processing.
def _process(cls, text, **kwargs):
"""Override this method for single-item processing.
Args:
text: str - Single text string
**kwargs: Additional parameters (already extracted from lists)
Returns:
str - Processed text
"""
raise NotImplementedError(f"{cls.__name__} must implement _process method")
@classmethod
def _group_process(cls, texts, **kwargs):
"""Override this method for group processing.
Args:
texts: list[str] - List of text strings
@ -443,7 +583,9 @@ class TextProcessingNode(io.ComfyNode):
Returns:
list[str] - Processed texts
"""
raise NotImplementedError(f"{cls.__name__} must implement _process method")
raise NotImplementedError(
f"{cls.__name__} must implement _group_process method"
)
# ========== Image Transform Nodes ==========
@ -465,31 +607,28 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, width, height, mode):
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
def _process(cls, image, width, height, mode):
img = tensor_to_pil(image)
if mode == "stretch":
if mode == "stretch":
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "crop_center":
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
if img.width != width or img.height != height:
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "crop_center":
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
if img.width != width or img.height != height:
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "pad":
img.thumbnail((width, height), Image.Resampling.LANCZOS)
new_img = Image.new("RGB", (width, height), (0, 0, 0))
paste_x = (width - img.width) // 2
paste_y = (height - img.height) // 2
new_img.paste(img, (paste_x, paste_y))
img = new_img
elif mode == "pad":
img.thumbnail((width, height), Image.Resampling.LANCZOS)
new_img = Image.new("RGB", (width, height), (0, 0, 0))
paste_x = (width - img.width) // 2
paste_y = (height - img.height) // 2
new_img.paste(img, (paste_x, paste_y))
img = new_img
output_images.append(pil_to_tensor(img))
return output_images
return pil_to_tensor(img)
class ResizeImagesToPixelCountNode(ImageProcessingNode):
@ -514,18 +653,15 @@ class ResizeImagesToPixelCountNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, pixel_count, steps):
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
w, h = img.size
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
new_w = int(h * pixel_count_ratio / steps) * steps
new_h = int(w * pixel_count_ratio / steps) * steps
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
output_images.append(pil_to_tensor(img))
return output_images
def _process(cls, image, pixel_count, steps):
img = tensor_to_pil(image)
w, h = img.size
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
new_w = int(w * pixel_count_ratio / steps) * steps
new_h = int(h * pixel_count_ratio / steps) * steps
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
return pil_to_tensor(img)
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
@ -543,20 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, shorter_edge):
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
w, h = img.size
if w < h:
new_w = shorter_edge
new_h = int(h * (shorter_edge / w))
else:
new_h = shorter_edge
new_w = int(w * (shorter_edge / h))
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
output_images.append(pil_to_tensor(img))
return output_images
def _process(cls, image, shorter_edge):
img = tensor_to_pil(image)
w, h = img.size
if w < h:
new_w = shorter_edge
new_h = int(h * (shorter_edge / w))
else:
new_h = shorter_edge
new_w = int(w * (shorter_edge / h))
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
return pil_to_tensor(img)
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
@ -574,20 +707,17 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, longer_edge):
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
w, h = img.size
if w > h:
new_w = longer_edge
new_h = int(h * (longer_edge / w))
else:
new_h = longer_edge
new_w = int(w * (longer_edge / h))
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
output_images.append(pil_to_tensor(img))
return output_images
def _process(cls, image, longer_edge):
img = tensor_to_pil(image)
w, h = img.size
if w > h:
new_w = longer_edge
new_h = int(h * (longer_edge / w))
else:
new_h = longer_edge
new_w = int(w * (longer_edge / h))
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
return pil_to_tensor(img)
class CenterCropImagesNode(ImageProcessingNode):
@ -600,17 +730,14 @@ class CenterCropImagesNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, width, height):
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
output_images.append(pil_to_tensor(img))
return output_images
def _process(cls, image, width, height):
img = tensor_to_pil(image)
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
return pil_to_tensor(img)
class RandomCropImagesNode(ImageProcessingNode):
@ -628,20 +755,17 @@ class RandomCropImagesNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, width, height, seed):
def _process(cls, image, width, height, seed):
np.random.seed(seed % (2**32 - 1))
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
max_left = max(0, img.width - width)
max_top = max(0, img.height - height)
left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
output_images.append(pil_to_tensor(img))
return output_images
img = tensor_to_pil(image)
max_left = max(0, img.width - width)
max_top = max(0, img.height - height)
left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
return pil_to_tensor(img)
class FlipImagesNode(ImageProcessingNode):
@ -658,16 +782,13 @@ class FlipImagesNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, direction):
output_images = []
for img_tensor in images:
img = tensor_to_pil(img_tensor)
if direction == "horizontal":
img = img.transpose(Image.FLIP_LEFT_RIGHT)
else:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
output_images.append(pil_to_tensor(img))
return output_images
def _process(cls, image, direction):
img = tensor_to_pil(image)
if direction == "horizontal":
img = img.transpose(Image.FLIP_LEFT_RIGHT)
else:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
return pil_to_tensor(img)
class NormalizeImagesNode(ImageProcessingNode):
@ -692,8 +813,8 @@ class NormalizeImagesNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, mean, std):
return [(img - mean) / std for img in images]
def _process(cls, image, mean, std):
return (image - mean) / std
class AdjustBrightnessNode(ImageProcessingNode):
@ -711,8 +832,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, factor):
return [(img * factor).clamp(0.0, 1.0) for img in images]
def _process(cls, image, factor):
return (image * factor).clamp(0.0, 1.0)
class AdjustContrastNode(ImageProcessingNode):
@ -730,14 +851,15 @@ class AdjustContrastNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, factor):
return [((img - 0.5) * factor + 0.5).clamp(0.0, 1.0) for img in images]
def _process(cls, image, factor):
return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0)
class ShuffleDatasetNode(ImageProcessingNode):
node_id = "ShuffleDataset"
display_name = "Shuffle Image Dataset"
description = "Randomly shuffle the order of images in the dataset."
is_group_process = True # Requires full list to shuffle
extra_inputs = [
io.Int.Input(
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
@ -745,7 +867,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
]
@classmethod
def _process(cls, images, seed):
def _group_process(cls, images, seed):
np.random.seed(seed % (2**32 - 1))
indices = np.random.permutation(len(images))
return [images[i] for i in indices]
@ -804,8 +926,8 @@ class TextToLowercaseNode(TextProcessingNode):
description = "Convert all texts to lowercase."
@classmethod
def _process(cls, texts):
return [text.lower() for text in texts]
def _process(cls, text):
return text.lower()
class TextToUppercaseNode(TextProcessingNode):
@ -814,8 +936,8 @@ class TextToUppercaseNode(TextProcessingNode):
description = "Convert all texts to uppercase."
@classmethod
def _process(cls, texts):
return [text.upper() for text in texts]
def _process(cls, text):
return text.upper()
class TruncateTextNode(TextProcessingNode):
@ -829,8 +951,8 @@ class TruncateTextNode(TextProcessingNode):
]
@classmethod
def _process(cls, texts, max_length):
return [text[:max_length] for text in texts]
def _process(cls, text, max_length):
return text[:max_length]
class AddTextPrefixNode(TextProcessingNode):
@ -842,8 +964,8 @@ class AddTextPrefixNode(TextProcessingNode):
]
@classmethod
def _process(cls, texts, prefix):
return [prefix + text for text in texts]
def _process(cls, text, prefix):
return prefix + text
class AddTextSuffixNode(TextProcessingNode):
@ -855,8 +977,8 @@ class AddTextSuffixNode(TextProcessingNode):
]
@classmethod
def _process(cls, texts, suffix):
return [text + suffix for text in texts]
def _process(cls, text, suffix):
return text + suffix
class ReplaceTextNode(TextProcessingNode):
@ -869,8 +991,8 @@ class ReplaceTextNode(TextProcessingNode):
]
@classmethod
def _process(cls, texts, find, replace):
return [text.replace(find, replace) for text in texts]
def _process(cls, text, find, replace):
return text.replace(find, replace)
class StripWhitespaceNode(TextProcessingNode):
@ -879,8 +1001,189 @@ class StripWhitespaceNode(TextProcessingNode):
description = "Strip leading and trailing whitespace from all texts."
@classmethod
def _process(cls, texts):
return [text.strip() for text in texts]
def _process(cls, text):
return text.strip()
# ========== Group Processing Example Nodes ==========
class ImageDeduplicationNode(ImageProcessingNode):
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
node_id = "ImageDeduplication"
display_name = "Image Deduplication"
description = "Remove duplicate or very similar images from the dataset."
is_group_process = True # Requires full list to compare images
extra_inputs = [
io.Float.Input(
"similarity_threshold",
default=0.95,
min=0.0,
max=1.0,
tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
),
]
@classmethod
def _group_process(cls, images, similarity_threshold):
"""Remove duplicate images using perceptual hashing."""
if len(images) == 0:
return []
# Compute simple perceptual hash for each image
def compute_hash(img_tensor):
"""Compute a simple perceptual hash by resizing to 8x8 and comparing to average."""
img = tensor_to_pil(img_tensor)
# Resize to 8x8
img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L")
# Get pixels
pixels = list(img_small.getdata())
# Compute average
avg = sum(pixels) / len(pixels)
# Create hash (1 if above average, 0 otherwise)
hash_bits = "".join("1" if p > avg else "0" for p in pixels)
return hash_bits
def hamming_distance(hash1, hash2):
"""Compute Hamming distance between two hash strings."""
return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
# Compute hashes for all images
hashes = [compute_hash(img) for img in images]
# Find duplicates
keep_indices = []
for i in range(len(images)):
is_duplicate = False
for j in keep_indices:
# Compare hashes
distance = hamming_distance(hashes[i], hashes[j])
similarity = 1.0 - (distance / 64.0) # 64 bits total
if similarity >= similarity_threshold:
is_duplicate = True
logging.info(
f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
)
break
if not is_duplicate:
keep_indices.append(i)
# Return only unique images
unique_images = [images[i] for i in keep_indices]
logging.info(
f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
)
return unique_images
class ImageGridNode(ImageProcessingNode):
"""Combine multiple images into a single grid/collage."""
node_id = "ImageGrid"
display_name = "Image Grid"
description = "Arrange multiple images into a grid layout."
is_group_process = True # Requires full list to create grid
is_output_list = False # Outputs single grid image
extra_inputs = [
io.Int.Input(
"columns",
default=4,
min=1,
max=20,
tooltip="Number of columns in the grid.",
),
io.Int.Input(
"cell_width",
default=256,
min=32,
max=2048,
tooltip="Width of each cell in the grid.",
),
io.Int.Input(
"cell_height",
default=256,
min=32,
max=2048,
tooltip="Height of each cell in the grid.",
),
io.Int.Input(
"padding", default=4, min=0, max=50, tooltip="Padding between images."
),
]
@classmethod
def _group_process(cls, images, columns, cell_width, cell_height, padding):
"""Arrange images into a grid."""
if len(images) == 0:
raise ValueError("Cannot create grid from empty image list")
# Calculate grid dimensions
num_images = len(images)
rows = (num_images + columns - 1) // columns # Ceiling division
# Calculate total grid size
grid_width = columns * cell_width + (columns - 1) * padding
grid_height = rows * cell_height + (rows - 1) * padding
# Create blank grid
grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))
# Place images
for idx, img_tensor in enumerate(images):
row = idx // columns
col = idx % columns
# Convert to PIL and resize to cell size
img = tensor_to_pil(img_tensor)
img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS)
# Calculate position
x = col * (cell_width + padding)
y = row * (cell_height + padding)
# Paste into grid
grid.paste(img, (x, y))
logging.info(
f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
)
return pil_to_tensor(grid)
class MergeImageListsNode(ImageProcessingNode):
"""Merge multiple image lists into a single list."""
node_id = "MergeImageLists"
display_name = "Merge Image Lists"
description = "Concatenate multiple image lists into one."
is_group_process = True # Receives images as list
@classmethod
def _group_process(cls, images):
"""Simply return the images list (already merged by input handling)."""
# When multiple list inputs are connected, they're concatenated
# For now, this is a simple pass-through
logging.info(f"Merged image list contains {len(images)} images")
return images
class MergeTextListsNode(TextProcessingNode):
"""Merge multiple text lists into a single list."""
node_id = "MergeTextLists"
display_name = "Merge Text Lists"
description = "Concatenate multiple text lists into one."
is_group_process = True # Receives texts as list
@classmethod
def _group_process(cls, texts):
"""Simply return the texts list (already merged by input handling)."""
# When multiple list inputs are connected, they're concatenated
# For now, this is a simple pass-through
logging.info(f"Merged text list contains {len(texts)} texts")
return texts
# ========== Training Dataset Nodes ==========
@ -1182,6 +1485,11 @@ class DatasetExtension(ComfyExtension):
AddTextSuffixNode,
ReplaceTextNode,
StripWhitespaceNode,
# Group processing examples
ImageDeduplicationNode,
ImageGridNode,
MergeImageListsNode,
MergeTextListsNode,
# Training dataset nodes
MakeTrainingDataset,
SaveTrainingDataset,