mirror of https://github.com/comfyanonymous/ComfyUI.git
use single-item processing instead of an input list when not needed
This commit is contained in:
parent 9217d54a8c
commit 0318d2d60c
@@ -309,29 +309,71 @@ def pil_to_tensor(img):
 class ImageProcessingNode(io.ComfyNode):
-    """Base class for image processing nodes that operate on lists of images.
+    """Base class for image processing nodes that operate on images.
 
     Child classes should set:
         node_id: Unique node identifier (required)
         display_name: Display name (optional, defaults to node_id)
         description: Node description (optional)
         extra_inputs: List of additional io.Input objects beyond "images" (optional)
+        is_group_process: None (auto-detect), True (group), or False (individual) (optional)
+        is_output_list: True (list output) or False (single output) (optional, default True)
 
-    Child classes must implement:
-        _process(cls, images, **kwargs) -> list[tensor]
+    Child classes must implement ONE of:
+        _process(cls, image, **kwargs) -> tensor (for single-item processing)
+        _group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
     """
 
     node_id = None
     display_name = None
     description = None
     extra_inputs = []
+    is_group_process = None  # None = auto-detect, True/False = explicit
+    is_output_list = True  # Configurable output mode
 
+    @classmethod
+    def _detect_processing_mode(cls):
+        """Detect whether this node uses group or individual processing.
+
+        Returns:
+            bool: True if group processing, False if individual processing
+        """
+        # Explicit setting takes precedence
+        if cls.is_group_process is not None:
+            return cls.is_group_process
+
+        # Check which method is overridden
+        base_class = ImageProcessingNode
+        has_process = cls._process is not base_class._process
+        has_group = cls._group_process is not base_class._group_process
+
+        if has_process and has_group:
+            raise ValueError(
+                f"{cls.__name__}: Cannot override both _process and _group_process. "
+                "Override only one, or set is_group_process explicitly."
+            )
+        if not has_process and not has_group:
+            raise ValueError(
+                f"{cls.__name__}: Must override either _process or _group_process"
+            )
+
+        return has_group
+
     @classmethod
     def define_schema(cls):
         if cls.node_id is None:
             raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
 
-        inputs = [io.Image.Input("images", tooltip="List of images to process.")]
+        is_group = cls._detect_processing_mode()
+
+        inputs = [
+            io.Image.Input(
+                "images",
+                tooltip=(
+                    "List of images to process." if is_group else "Image to process."
+                ),
+            )
+        ]
         inputs.extend(cls.extra_inputs)
 
         return io.Schema(
@@ -339,12 +381,12 @@ class ImageProcessingNode(io.ComfyNode):
             display_name=cls.display_name or cls.node_id,
             category="dataset/image",
             is_experimental=True,
-            is_input_list=True,
+            is_input_list=is_group,  # True for group, False for individual
             inputs=inputs,
             outputs=[
                 io.Image.Output(
                     display_name="images",
-                    is_output_list=True,
+                    is_output_list=cls.is_output_list,
                     tooltip="Processed images",
                 )
             ],
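Note: to see what the new contract asks of a child class, here is a minimal hypothetical subclass (the node name and the inversion math are illustrative only, not part of this commit). With only _process overridden, _detect_processing_mode() returns False, so the schema is built with is_input_list=False and the executor calls the node once per image:

# Hypothetical example, not part of the diff: a single-item node under the new API.
class InvertImagesExampleNode(ImageProcessingNode):
    node_id = "InvertImagesExample"  # illustrative identifier
    description = "Invert image colors."
    # is_group_process stays None: only _process is overridden,
    # so the processing mode is auto-detected as individual.

    @classmethod
    def _process(cls, image, **kwargs):
        # image arrives as a single tensor in [0, 1]; return a single tensor
        return 1.0 - image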
@@ -352,8 +394,10 @@ class ImageProcessingNode(io.ComfyNode):
 
     @classmethod
     def execute(cls, images, **kwargs):
-        """Execute the node. Extracts scalar values and calls _process."""
-        # Extract scalar values from lists (due to is_input_list=True)
+        """Execute the node. Routes to _process or _group_process based on mode."""
+        is_group = cls._detect_processing_mode()
+
+        # Extract scalar values from lists for parameters
         params = {}
         for k, v in kwargs.items():
             if isinstance(v, list) and len(v) == 1:
@@ -361,12 +405,37 @@ class ImageProcessingNode(io.ComfyNode):
             else:
                 params[k] = v
 
-        result = cls._process(images, **params)
-        return io.NodeOutput(result)
+        if is_group:
+            # Group processing: images is list, call _group_process
+            result = cls._group_process(images, **params)
+        else:
+            # Individual processing: images is single item, call _process
+            result = cls._process(images, **params)
+
+        # Wrap result based on is_output_list
+        if cls.is_output_list:
+            # Result should already be a list (or will be wrapped for individual)
+            return io.NodeOutput(result if is_group else [result])
+        else:
+            # Single output - wrap in list for NodeOutput
+            return io.NodeOutput([result])
 
     @classmethod
-    def _process(cls, images, **kwargs):
-        """Override this method in subclasses to implement specific processing.
+    def _process(cls, image, **kwargs):
+        """Override this method for single-item processing.
 
         Args:
+            image: tensor - Single image tensor
+            **kwargs: Additional parameters (already extracted from lists)
+
+        Returns:
+            tensor - Processed image
+        """
+        raise NotImplementedError(f"{cls.__name__} must implement _process method")
+
+    @classmethod
+    def _group_process(cls, images, **kwargs):
+        """Override this method for group processing.
+
+        Args:
             images: list[tensor] - List of image tensors
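Note: the scalar extraction kept by this hunk exists because widget parameters arrive wrapped in lists when the input is treated as a list; execute unwraps one-element lists before dispatching to either method. A standalone sketch of that behavior (values hypothetical):

# Parameter unwrapping as done in execute, on hypothetical values.
kwargs = {"factor": [1.5], "mode": ["stretch"], "points": [1, 2, 3]}
params = {
    k: v[0] if isinstance(v, list) and len(v) == 1 else v
    for k, v in kwargs.items()
}
assert params == {"factor": 1.5, "mode": "stretch", "points": [1, 2, 3]}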
@@ -375,33 +444,75 @@ class ImageProcessingNode(io.ComfyNode):
         Returns:
             list[tensor] - Processed images
         """
-        raise NotImplementedError(f"{cls.__name__} must implement _process method")
+        raise NotImplementedError(
+            f"{cls.__name__} must implement _group_process method"
+        )
 
 
 class TextProcessingNode(io.ComfyNode):
-    """Base class for text processing nodes that operate on lists of texts.
+    """Base class for text processing nodes that operate on texts.
 
     Child classes should set:
         node_id: Unique node identifier (required)
         display_name: Display name (optional, defaults to node_id)
         description: Node description (optional)
         extra_inputs: List of additional io.Input objects beyond "texts" (optional)
+        is_group_process: None (auto-detect), True (group), or False (individual) (optional)
+        is_output_list: True (list output) or False (single output) (optional, default True)
 
-    Child classes must implement:
-        _process(cls, texts, **kwargs) -> list[str]
+    Child classes must implement ONE of:
+        _process(cls, text, **kwargs) -> str (for single-item processing)
+        _group_process(cls, texts, **kwargs) -> list[str] (for group processing)
     """
 
     node_id = None
     display_name = None
     description = None
     extra_inputs = []
+    is_group_process = None  # None = auto-detect, True/False = explicit
+    is_output_list = True  # Configurable output mode
 
+    @classmethod
+    def _detect_processing_mode(cls):
+        """Detect whether this node uses group or individual processing.
+
+        Returns:
+            bool: True if group processing, False if individual processing
+        """
+        # Explicit setting takes precedence
+        if cls.is_group_process is not None:
+            return cls.is_group_process
+
+        # Check which method is overridden
+        base_class = TextProcessingNode
+        has_process = cls._process is not base_class._process
+        has_group = cls._group_process is not base_class._group_process
+
+        if has_process and has_group:
+            raise ValueError(
+                f"{cls.__name__}: Cannot override both _process and _group_process. "
+                "Override only one, or set is_group_process explicitly."
+            )
+        if not has_process and not has_group:
+            raise ValueError(
+                f"{cls.__name__}: Must override either _process or _group_process"
+            )
+
+        return has_group
+
     @classmethod
     def define_schema(cls):
         if cls.node_id is None:
             raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
 
-        inputs = [io.String.Input("texts", tooltip="List of texts to process.")]
+        is_group = cls._detect_processing_mode()
+
+        inputs = [
+            io.String.Input(
+                "texts",
+                tooltip="List of texts to process." if is_group else "Text to process.",
+            )
+        ]
         inputs.extend(cls.extra_inputs)
 
         return io.Schema(
@@ -409,19 +520,23 @@ class TextProcessingNode(io.ComfyNode):
             display_name=cls.display_name or cls.node_id,
             category="dataset/text",
             is_experimental=True,
-            is_input_list=True,
+            is_input_list=is_group,  # True for group, False for individual
             inputs=inputs,
             outputs=[
                 io.String.Output(
-                    display_name="texts", is_output_list=True, tooltip="Processed texts"
+                    display_name="texts",
+                    is_output_list=cls.is_output_list,
+                    tooltip="Processed texts",
                 )
             ],
         )
 
     @classmethod
     def execute(cls, texts, **kwargs):
-        """Execute the node. Extracts scalar values and calls _process."""
-        # Extract scalar values from lists (due to is_input_list=True)
+        """Execute the node. Routes to _process or _group_process based on mode."""
+        is_group = cls._detect_processing_mode()
+
+        # Extract scalar values from lists for parameters
         params = {}
         for k, v in kwargs.items():
             if isinstance(v, list) and len(v) == 1:
@@ -429,12 +544,37 @@ class TextProcessingNode(io.ComfyNode):
             else:
                 params[k] = v
 
-        result = cls._process(texts, **params)
-        return io.NodeOutput(result)
+        if is_group:
+            # Group processing: texts is list, call _group_process
+            result = cls._group_process(texts, **params)
+        else:
+            # Individual processing: texts is single item, call _process
+            result = cls._process(texts, **params)
+
+        # Wrap result based on is_output_list
+        if cls.is_output_list:
+            # Result should already be a list (or will be wrapped for individual)
+            return io.NodeOutput(result if is_group else [result])
+        else:
+            # Single output - wrap in list for NodeOutput
+            return io.NodeOutput([result])
 
     @classmethod
-    def _process(cls, texts, **kwargs):
-        """Override this method in subclasses to implement specific processing.
+    def _process(cls, text, **kwargs):
+        """Override this method for single-item processing.
 
         Args:
+            text: str - Single text string
+            **kwargs: Additional parameters (already extracted from lists)
+
+        Returns:
+            str - Processed text
+        """
+        raise NotImplementedError(f"{cls.__name__} must implement _process method")
+
+    @classmethod
+    def _group_process(cls, texts, **kwargs):
+        """Override this method for group processing.
+
+        Args:
             texts: list[str] - List of text strings
@@ -443,7 +583,9 @@ class TextProcessingNode(io.ComfyNode):
         Returns:
             list[str] - Processed texts
         """
-        raise NotImplementedError(f"{cls.__name__} must implement _process method")
+        raise NotImplementedError(
+            f"{cls.__name__} must implement _group_process method"
+        )
 
 
 # ========== Image Transform Nodes ==========
@@ -465,31 +607,28 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, width, height, mode):
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-
-            if mode == "stretch":
-                img = img.resize((width, height), Image.Resampling.LANCZOS)
-            elif mode == "crop_center":
-                left = max(0, (img.width - width) // 2)
-                top = max(0, (img.height - height) // 2)
-                right = min(img.width, left + width)
-                bottom = min(img.height, top + height)
-                img = img.crop((left, top, right, bottom))
-                if img.width != width or img.height != height:
-                    img = img.resize((width, height), Image.Resampling.LANCZOS)
-            elif mode == "pad":
-                img.thumbnail((width, height), Image.Resampling.LANCZOS)
-                new_img = Image.new("RGB", (width, height), (0, 0, 0))
-                paste_x = (width - img.width) // 2
-                paste_y = (height - img.height) // 2
-                new_img.paste(img, (paste_x, paste_y))
-                img = new_img
-
-            output_images.append(pil_to_tensor(img))
-        return output_images
+    def _process(cls, image, width, height, mode):
+        img = tensor_to_pil(image)
+
+        if mode == "stretch":
+            img = img.resize((width, height), Image.Resampling.LANCZOS)
+        elif mode == "crop_center":
+            left = max(0, (img.width - width) // 2)
+            top = max(0, (img.height - height) // 2)
+            right = min(img.width, left + width)
+            bottom = min(img.height, top + height)
+            img = img.crop((left, top, right, bottom))
+            if img.width != width or img.height != height:
+                img = img.resize((width, height), Image.Resampling.LANCZOS)
+        elif mode == "pad":
+            img.thumbnail((width, height), Image.Resampling.LANCZOS)
+            new_img = Image.new("RGB", (width, height), (0, 0, 0))
+            paste_x = (width - img.width) // 2
+            paste_y = (height - img.height) // 2
+            new_img.paste(img, (paste_x, paste_y))
+            img = new_img
+
+        return pil_to_tensor(img)
 
 
 class ResizeImagesToPixelCountNode(ImageProcessingNode):
@@ -514,18 +653,15 @@ class ResizeImagesToPixelCountNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, pixel_count, steps):
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-            w, h = img.size
-            pixel_count_ratio = math.sqrt(pixel_count / (w * h))
-            new_w = int(h * pixel_count_ratio / steps) * steps
-            new_h = int(w * pixel_count_ratio / steps) * steps
-            logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
-            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            output_images.append(pil_to_tensor(img))
-        return output_images
+    def _process(cls, image, pixel_count, steps):
+        img = tensor_to_pil(image)
+        w, h = img.size
+        pixel_count_ratio = math.sqrt(pixel_count / (w * h))
+        new_w = int(w * pixel_count_ratio / steps) * steps
+        new_h = int(h * pixel_count_ratio / steps) * steps
+        logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
+        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        return pil_to_tensor(img)
 
 
 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
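Note: besides dropping the loop, this hunk fixes a width/height swap: the old code derived new_w from h and new_h from w, transposing the aspect ratio of non-square images. A worked example of the corrected arithmetic (inputs hypothetical):

import math

# Hypothetical inputs: 2000x1500 source, ~1 megapixel budget, 64-pixel alignment.
w, h, pixel_count, steps = 2000, 1500, 1024 * 1024, 64
ratio = math.sqrt(pixel_count / (w * h))  # ~0.5912
new_w = int(w * ratio / steps) * steps    # 1152
new_h = int(h * ratio / steps) * steps    # 832
assert (new_w, new_h) == (1152, 832)      # 958,464 pixels, aspect preserved
# The old swapped formula gives (832, 1152) here, transposing the aspect ratio.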
@@ -543,20 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, shorter_edge):
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-            w, h = img.size
-            if w < h:
-                new_w = shorter_edge
-                new_h = int(h * (shorter_edge / w))
-            else:
-                new_h = shorter_edge
-                new_w = int(w * (shorter_edge / h))
-            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            output_images.append(pil_to_tensor(img))
-        return output_images
+    def _process(cls, image, shorter_edge):
+        img = tensor_to_pil(image)
+        w, h = img.size
+        if w < h:
+            new_w = shorter_edge
+            new_h = int(h * (shorter_edge / w))
+        else:
+            new_h = shorter_edge
+            new_w = int(w * (shorter_edge / h))
+        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        return pil_to_tensor(img)
 
 
 class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
@@ -574,20 +707,17 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, longer_edge):
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-            w, h = img.size
-            if w > h:
-                new_w = longer_edge
-                new_h = int(h * (longer_edge / w))
-            else:
-                new_h = longer_edge
-                new_w = int(w * (longer_edge / h))
-            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            output_images.append(pil_to_tensor(img))
-        return output_images
+    def _process(cls, image, longer_edge):
+        img = tensor_to_pil(image)
+        w, h = img.size
+        if w > h:
+            new_w = longer_edge
+            new_h = int(h * (longer_edge / w))
+        else:
+            new_h = longer_edge
+            new_w = int(w * (longer_edge / h))
+        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        return pil_to_tensor(img)
 
 
 class CenterCropImagesNode(ImageProcessingNode):
@@ -600,17 +730,14 @@ class CenterCropImagesNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, width, height):
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-            left = max(0, (img.width - width) // 2)
-            top = max(0, (img.height - height) // 2)
-            right = min(img.width, left + width)
-            bottom = min(img.height, top + height)
-            img = img.crop((left, top, right, bottom))
-            output_images.append(pil_to_tensor(img))
-        return output_images
+    def _process(cls, image, width, height):
+        img = tensor_to_pil(image)
+        left = max(0, (img.width - width) // 2)
+        top = max(0, (img.height - height) // 2)
+        right = min(img.width, left + width)
+        bottom = min(img.height, top + height)
+        img = img.crop((left, top, right, bottom))
+        return pil_to_tensor(img)
 
 
 class RandomCropImagesNode(ImageProcessingNode):
@@ -628,20 +755,17 @@ class RandomCropImagesNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, width, height, seed):
+    def _process(cls, image, width, height, seed):
         np.random.seed(seed % (2**32 - 1))
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-            max_left = max(0, img.width - width)
-            max_top = max(0, img.height - height)
-            left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
-            top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
-            right = min(img.width, left + width)
-            bottom = min(img.height, top + height)
-            img = img.crop((left, top, right, bottom))
-            output_images.append(pil_to_tensor(img))
-        return output_images
+        img = tensor_to_pil(image)
+        max_left = max(0, img.width - width)
+        max_top = max(0, img.height - height)
+        left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
+        top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
+        right = min(img.width, left + width)
+        bottom = min(img.height, top + height)
+        img = img.crop((left, top, right, bottom))
+        return pil_to_tensor(img)
 
 
 class FlipImagesNode(ImageProcessingNode):
@@ -658,16 +782,13 @@ class FlipImagesNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, direction):
-        output_images = []
-        for img_tensor in images:
-            img = tensor_to_pil(img_tensor)
-            if direction == "horizontal":
-                img = img.transpose(Image.FLIP_LEFT_RIGHT)
-            else:
-                img = img.transpose(Image.FLIP_TOP_BOTTOM)
-            output_images.append(pil_to_tensor(img))
-        return output_images
+    def _process(cls, image, direction):
+        img = tensor_to_pil(image)
+        if direction == "horizontal":
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        else:
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+        return pil_to_tensor(img)
 
 
 class NormalizeImagesNode(ImageProcessingNode):
@@ -692,8 +813,8 @@ class NormalizeImagesNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, mean, std):
-        return [(img - mean) / std for img in images]
+    def _process(cls, image, mean, std):
+        return (image - mean) / std
 
 
 class AdjustBrightnessNode(ImageProcessingNode):
@@ -711,8 +832,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, factor):
-        return [(img * factor).clamp(0.0, 1.0) for img in images]
+    def _process(cls, image, factor):
+        return (image * factor).clamp(0.0, 1.0)
 
 
 class AdjustContrastNode(ImageProcessingNode):
@@ -730,14 +851,15 @@ class AdjustContrastNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, factor):
-        return [((img - 0.5) * factor + 0.5).clamp(0.0, 1.0) for img in images]
+    def _process(cls, image, factor):
+        return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0)
 
 
 class ShuffleDatasetNode(ImageProcessingNode):
     node_id = "ShuffleDataset"
     display_name = "Shuffle Image Dataset"
     description = "Randomly shuffle the order of images in the dataset."
+    is_group_process = True  # Requires full list to shuffle
     extra_inputs = [
         io.Int.Input(
             "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
@@ -745,7 +867,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, images, seed):
+    def _group_process(cls, images, seed):
         np.random.seed(seed % (2**32 - 1))
         indices = np.random.permutation(len(images))
         return [images[i] for i in indices]
@@ -804,8 +926,8 @@ class TextToLowercaseNode(TextProcessingNode):
     description = "Convert all texts to lowercase."
 
     @classmethod
-    def _process(cls, texts):
-        return [text.lower() for text in texts]
+    def _process(cls, text):
+        return text.lower()
 
 
 class TextToUppercaseNode(TextProcessingNode):
@@ -814,8 +936,8 @@ class TextToUppercaseNode(TextProcessingNode):
     description = "Convert all texts to uppercase."
 
     @classmethod
-    def _process(cls, texts):
-        return [text.upper() for text in texts]
+    def _process(cls, text):
+        return text.upper()
 
 
 class TruncateTextNode(TextProcessingNode):
@@ -829,8 +951,8 @@ class TruncateTextNode(TextProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, texts, max_length):
-        return [text[:max_length] for text in texts]
+    def _process(cls, text, max_length):
+        return text[:max_length]
 
 
 class AddTextPrefixNode(TextProcessingNode):
@@ -842,8 +964,8 @@ class AddTextPrefixNode(TextProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, texts, prefix):
-        return [prefix + text for text in texts]
+    def _process(cls, text, prefix):
+        return prefix + text
 
 
 class AddTextSuffixNode(TextProcessingNode):
@@ -855,8 +977,8 @@ class AddTextSuffixNode(TextProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, texts, suffix):
-        return [text + suffix for text in texts]
+    def _process(cls, text, suffix):
+        return text + suffix
 
 
 class ReplaceTextNode(TextProcessingNode):
@@ -869,8 +991,8 @@ class ReplaceTextNode(TextProcessingNode):
     ]
 
     @classmethod
-    def _process(cls, texts, find, replace):
-        return [text.replace(find, replace) for text in texts]
+    def _process(cls, text, find, replace):
+        return text.replace(find, replace)
 
 
 class StripWhitespaceNode(TextProcessingNode):
@@ -879,8 +1001,189 @@ class StripWhitespaceNode(TextProcessingNode):
     description = "Strip leading and trailing whitespace from all texts."
 
     @classmethod
-    def _process(cls, texts):
-        return [text.strip() for text in texts]
+    def _process(cls, text):
+        return text.strip()
+
+
+# ========== Group Processing Example Nodes ==========
+
+
+class ImageDeduplicationNode(ImageProcessingNode):
+    """Remove duplicate or very similar images from the dataset using perceptual hashing."""
+
+    node_id = "ImageDeduplication"
+    display_name = "Image Deduplication"
+    description = "Remove duplicate or very similar images from the dataset."
+    is_group_process = True  # Requires full list to compare images
+    extra_inputs = [
+        io.Float.Input(
+            "similarity_threshold",
+            default=0.95,
+            min=0.0,
+            max=1.0,
+            tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
+        ),
+    ]
+
+    @classmethod
+    def _group_process(cls, images, similarity_threshold):
+        """Remove duplicate images using perceptual hashing."""
+        if len(images) == 0:
+            return []
+
+        # Compute simple perceptual hash for each image
+        def compute_hash(img_tensor):
+            """Compute a simple perceptual hash by resizing to 8x8 and comparing to average."""
+            img = tensor_to_pil(img_tensor)
+            # Resize to 8x8
+            img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L")
+            # Get pixels
+            pixels = list(img_small.getdata())
+            # Compute average
+            avg = sum(pixels) / len(pixels)
+            # Create hash (1 if above average, 0 otherwise)
+            hash_bits = "".join("1" if p > avg else "0" for p in pixels)
+            return hash_bits
+
+        def hamming_distance(hash1, hash2):
+            """Compute Hamming distance between two hash strings."""
+            return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
+
+        # Compute hashes for all images
+        hashes = [compute_hash(img) for img in images]
+
+        # Find duplicates
+        keep_indices = []
+        for i in range(len(images)):
+            is_duplicate = False
+            for j in keep_indices:
+                # Compare hashes
+                distance = hamming_distance(hashes[i], hashes[j])
+                similarity = 1.0 - (distance / 64.0)  # 64 bits total
+                if similarity >= similarity_threshold:
+                    is_duplicate = True
+                    logging.info(
+                        f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
+                    )
+                    break
+
+            if not is_duplicate:
+                keep_indices.append(i)
+
+        # Return only unique images
+        unique_images = [images[i] for i in keep_indices]
+        logging.info(
+            f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
+        )
+        return unique_images
+
+
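Note: to make similarity_threshold concrete, the hash above has 64 bits, so similarity is quantized in steps of 1/64 (about 0.016) and each threshold translates into a maximum Hamming distance. A small sketch of that mapping (thresholds illustrative):

# How a similarity threshold maps onto allowed differing bits in the 64-bit hash.
def max_differing_bits(threshold, bits=64):
    # similarity = 1 - distance / bits >= threshold  <=>  distance <= bits * (1 - threshold)
    return int(bits * (1.0 - threshold))

assert max_differing_bits(0.95) == 3  # near-duplicates: at most 3 of 64 bits differ
assert max_differing_bits(1.0) == 0   # only exact hash matches are dropped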
+class ImageGridNode(ImageProcessingNode):
+    """Combine multiple images into a single grid/collage."""
+
+    node_id = "ImageGrid"
+    display_name = "Image Grid"
+    description = "Arrange multiple images into a grid layout."
+    is_group_process = True  # Requires full list to create grid
+    is_output_list = False  # Outputs single grid image
+    extra_inputs = [
+        io.Int.Input(
+            "columns",
+            default=4,
+            min=1,
+            max=20,
+            tooltip="Number of columns in the grid.",
+        ),
+        io.Int.Input(
+            "cell_width",
+            default=256,
+            min=32,
+            max=2048,
+            tooltip="Width of each cell in the grid.",
+        ),
+        io.Int.Input(
+            "cell_height",
+            default=256,
+            min=32,
+            max=2048,
+            tooltip="Height of each cell in the grid.",
+        ),
+        io.Int.Input(
+            "padding", default=4, min=0, max=50, tooltip="Padding between images."
+        ),
+    ]
+
+    @classmethod
+    def _group_process(cls, images, columns, cell_width, cell_height, padding):
+        """Arrange images into a grid."""
+        if len(images) == 0:
+            raise ValueError("Cannot create grid from empty image list")
+
+        # Calculate grid dimensions
+        num_images = len(images)
+        rows = (num_images + columns - 1) // columns  # Ceiling division
+
+        # Calculate total grid size
+        grid_width = columns * cell_width + (columns - 1) * padding
+        grid_height = rows * cell_height + (rows - 1) * padding
+
+        # Create blank grid
+        grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))
+
+        # Place images
+        for idx, img_tensor in enumerate(images):
+            row = idx // columns
+            col = idx % columns
+
+            # Convert to PIL and resize to cell size
+            img = tensor_to_pil(img_tensor)
+            img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS)
+
+            # Calculate position
+            x = col * (cell_width + padding)
+            y = row * (cell_height + padding)
+
+            # Paste into grid
+            grid.paste(img, (x, y))
+
+        logging.info(
+            f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
+        )
+        return pil_to_tensor(grid)
+
+
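Note: a quick check of the layout arithmetic above (counts hypothetical): padding is added only between cells, not around the border, so 10 images at 4 columns produce a 3-row grid:

# Grid-size arithmetic from _group_process, on hypothetical inputs.
num_images, columns, cell_width, cell_height, padding = 10, 4, 256, 256, 4
rows = (num_images + columns - 1) // columns                 # ceil(10 / 4) = 3
grid_width = columns * cell_width + (columns - 1) * padding  # 1036
grid_height = rows * cell_height + (rows - 1) * padding      # 776
assert (rows, grid_width, grid_height) == (3, 1036, 776)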
+class MergeImageListsNode(ImageProcessingNode):
+    """Merge multiple image lists into a single list."""
+
+    node_id = "MergeImageLists"
+    display_name = "Merge Image Lists"
+    description = "Concatenate multiple image lists into one."
+    is_group_process = True  # Receives images as list
+
+    @classmethod
+    def _group_process(cls, images):
+        """Simply return the images list (already merged by input handling)."""
+        # When multiple list inputs are connected, they're concatenated
+        # For now, this is a simple pass-through
+        logging.info(f"Merged image list contains {len(images)} images")
+        return images
+
+
+class MergeTextListsNode(TextProcessingNode):
+    """Merge multiple text lists into a single list."""
+
+    node_id = "MergeTextLists"
+    display_name = "Merge Text Lists"
+    description = "Concatenate multiple text lists into one."
+    is_group_process = True  # Receives texts as list
+
+    @classmethod
+    def _group_process(cls, texts):
+        """Simply return the texts list (already merged by input handling)."""
+        # When multiple list inputs are connected, they're concatenated
+        # For now, this is a simple pass-through
+        logging.info(f"Merged text list contains {len(texts)} texts")
+        return texts
 
 
 # ========== Training Dataset Nodes ==========
@@ -1182,6 +1485,11 @@ class DatasetExtension(ComfyExtension):
             AddTextSuffixNode,
             ReplaceTextNode,
             StripWhitespaceNode,
+            # Group processing examples
+            ImageDeduplicationNode,
+            ImageGridNode,
+            MergeImageListsNode,
+            MergeTextListsNode,
             # Training dataset nodes
             MakeTrainingDataset,
             SaveTrainingDataset,