mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 12:20:16 +08:00
use single process instead of input list when no need
This commit is contained in:
parent
9217d54a8c
commit
0318d2d60c
@ -309,29 +309,71 @@ def pil_to_tensor(img):
|
|||||||
|
|
||||||
|
|
||||||
class ImageProcessingNode(io.ComfyNode):
|
class ImageProcessingNode(io.ComfyNode):
|
||||||
"""Base class for image processing nodes that operate on lists of images.
|
"""Base class for image processing nodes that operate on images.
|
||||||
|
|
||||||
Child classes should set:
|
Child classes should set:
|
||||||
node_id: Unique node identifier (required)
|
node_id: Unique node identifier (required)
|
||||||
display_name: Display name (optional, defaults to node_id)
|
display_name: Display name (optional, defaults to node_id)
|
||||||
description: Node description (optional)
|
description: Node description (optional)
|
||||||
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
||||||
|
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||||
|
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||||
|
|
||||||
Child classes must implement:
|
Child classes must implement ONE of:
|
||||||
_process(cls, images, **kwargs) -> list[tensor]
|
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
||||||
|
_group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_id = None
|
node_id = None
|
||||||
display_name = None
|
display_name = None
|
||||||
description = None
|
description = None
|
||||||
extra_inputs = []
|
extra_inputs = []
|
||||||
|
is_group_process = None # None = auto-detect, True/False = explicit
|
||||||
|
is_output_list = True # Configurable output mode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _detect_processing_mode(cls):
|
||||||
|
"""Detect whether this node uses group or individual processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if group processing, False if individual processing
|
||||||
|
"""
|
||||||
|
# Explicit setting takes precedence
|
||||||
|
if cls.is_group_process is not None:
|
||||||
|
return cls.is_group_process
|
||||||
|
|
||||||
|
# Check which method is overridden
|
||||||
|
base_class = ImageProcessingNode
|
||||||
|
has_process = cls._process is not base_class._process
|
||||||
|
has_group = cls._group_process is not base_class._group_process
|
||||||
|
|
||||||
|
if has_process and has_group:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__}: Cannot override both _process and _group_process. "
|
||||||
|
"Override only one, or set is_group_process explicitly."
|
||||||
|
)
|
||||||
|
if not has_process and not has_group:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__}: Must override either _process or _group_process"
|
||||||
|
)
|
||||||
|
|
||||||
|
return has_group
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
if cls.node_id is None:
|
if cls.node_id is None:
|
||||||
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
|
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
|
||||||
|
|
||||||
inputs = [io.Image.Input("images", tooltip="List of images to process.")]
|
is_group = cls._detect_processing_mode()
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
io.Image.Input(
|
||||||
|
"images",
|
||||||
|
tooltip=(
|
||||||
|
"List of images to process." if is_group else "Image to process."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
inputs.extend(cls.extra_inputs)
|
inputs.extend(cls.extra_inputs)
|
||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
@ -339,12 +381,12 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
display_name=cls.display_name or cls.node_id,
|
display_name=cls.display_name or cls.node_id,
|
||||||
category="dataset/image",
|
category="dataset/image",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
is_input_list=True,
|
is_input_list=is_group, # True for group, False for individual
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(
|
io.Image.Output(
|
||||||
display_name="images",
|
display_name="images",
|
||||||
is_output_list=True,
|
is_output_list=cls.is_output_list,
|
||||||
tooltip="Processed images",
|
tooltip="Processed images",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
@ -352,8 +394,10 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, images, **kwargs):
|
def execute(cls, images, **kwargs):
|
||||||
"""Execute the node. Extracts scalar values and calls _process."""
|
"""Execute the node. Routes to _process or _group_process based on mode."""
|
||||||
# Extract scalar values from lists (due to is_input_list=True)
|
is_group = cls._detect_processing_mode()
|
||||||
|
|
||||||
|
# Extract scalar values from lists for parameters
|
||||||
params = {}
|
params = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if isinstance(v, list) and len(v) == 1:
|
if isinstance(v, list) and len(v) == 1:
|
||||||
@ -361,12 +405,37 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
params[k] = v
|
params[k] = v
|
||||||
|
|
||||||
result = cls._process(images, **params)
|
if is_group:
|
||||||
return io.NodeOutput(result)
|
# Group processing: images is list, call _group_process
|
||||||
|
result = cls._group_process(images, **params)
|
||||||
|
else:
|
||||||
|
# Individual processing: images is single item, call _process
|
||||||
|
result = cls._process(images, **params)
|
||||||
|
|
||||||
|
# Wrap result based on is_output_list
|
||||||
|
if cls.is_output_list:
|
||||||
|
# Result should already be a list (or will be for individual)
|
||||||
|
return io.NodeOutput(result if is_group else [result])
|
||||||
|
else:
|
||||||
|
# Single output - wrap in list for NodeOutput
|
||||||
|
return io.NodeOutput([result])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, **kwargs):
|
def _process(cls, image, **kwargs):
|
||||||
"""Override this method in subclasses to implement specific processing.
|
"""Override this method for single-item processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: tensor - Single image tensor
|
||||||
|
**kwargs: Additional parameters (already extracted from lists)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor - Processed image
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.__name__} must implement _process method")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _group_process(cls, images, **kwargs):
|
||||||
|
"""Override this method for group processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
images: list[tensor] - List of image tensors
|
images: list[tensor] - List of image tensors
|
||||||
@ -375,33 +444,75 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
Returns:
|
Returns:
|
||||||
list[tensor] - Processed images
|
list[tensor] - Processed images
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"{cls.__name__} must implement _process method")
|
raise NotImplementedError(
|
||||||
|
f"{cls.__name__} must implement _group_process method"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextProcessingNode(io.ComfyNode):
|
class TextProcessingNode(io.ComfyNode):
|
||||||
"""Base class for text processing nodes that operate on lists of texts.
|
"""Base class for text processing nodes that operate on texts.
|
||||||
|
|
||||||
Child classes should set:
|
Child classes should set:
|
||||||
node_id: Unique node identifier (required)
|
node_id: Unique node identifier (required)
|
||||||
display_name: Display name (optional, defaults to node_id)
|
display_name: Display name (optional, defaults to node_id)
|
||||||
description: Node description (optional)
|
description: Node description (optional)
|
||||||
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
||||||
|
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
||||||
|
is_output_list: True (list output) or False (single output) (optional, default True)
|
||||||
|
|
||||||
Child classes must implement:
|
Child classes must implement ONE of:
|
||||||
_process(cls, texts, **kwargs) -> list[str]
|
_process(cls, text, **kwargs) -> str (for single-item processing)
|
||||||
|
_group_process(cls, texts, **kwargs) -> list[str] (for group processing)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_id = None
|
node_id = None
|
||||||
display_name = None
|
display_name = None
|
||||||
description = None
|
description = None
|
||||||
extra_inputs = []
|
extra_inputs = []
|
||||||
|
is_group_process = None # None = auto-detect, True/False = explicit
|
||||||
|
is_output_list = True # Configurable output mode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _detect_processing_mode(cls):
|
||||||
|
"""Detect whether this node uses group or individual processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if group processing, False if individual processing
|
||||||
|
"""
|
||||||
|
# Explicit setting takes precedence
|
||||||
|
if cls.is_group_process is not None:
|
||||||
|
return cls.is_group_process
|
||||||
|
|
||||||
|
# Check which method is overridden
|
||||||
|
base_class = TextProcessingNode
|
||||||
|
has_process = cls._process is not base_class._process
|
||||||
|
has_group = cls._group_process is not base_class._group_process
|
||||||
|
|
||||||
|
if has_process and has_group:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__}: Cannot override both _process and _group_process. "
|
||||||
|
"Override only one, or set is_group_process explicitly."
|
||||||
|
)
|
||||||
|
if not has_process and not has_group:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__}: Must override either _process or _group_process"
|
||||||
|
)
|
||||||
|
|
||||||
|
return has_group
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
if cls.node_id is None:
|
if cls.node_id is None:
|
||||||
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
|
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
|
||||||
|
|
||||||
inputs = [io.String.Input("texts", tooltip="List of texts to process.")]
|
is_group = cls._detect_processing_mode()
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
io.String.Input(
|
||||||
|
"texts",
|
||||||
|
tooltip="List of texts to process." if is_group else "Text to process.",
|
||||||
|
)
|
||||||
|
]
|
||||||
inputs.extend(cls.extra_inputs)
|
inputs.extend(cls.extra_inputs)
|
||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
@ -409,19 +520,23 @@ class TextProcessingNode(io.ComfyNode):
|
|||||||
display_name=cls.display_name or cls.node_id,
|
display_name=cls.display_name or cls.node_id,
|
||||||
category="dataset/text",
|
category="dataset/text",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
is_input_list=True,
|
is_input_list=is_group, # True for group, False for individual
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=[
|
outputs=[
|
||||||
io.String.Output(
|
io.String.Output(
|
||||||
display_name="texts", is_output_list=True, tooltip="Processed texts"
|
display_name="texts",
|
||||||
|
is_output_list=cls.is_output_list,
|
||||||
|
tooltip="Processed texts",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, texts, **kwargs):
|
def execute(cls, texts, **kwargs):
|
||||||
"""Execute the node. Extracts scalar values and calls _process."""
|
"""Execute the node. Routes to _process or _group_process based on mode."""
|
||||||
# Extract scalar values from lists (due to is_input_list=True)
|
is_group = cls._detect_processing_mode()
|
||||||
|
|
||||||
|
# Extract scalar values from lists for parameters
|
||||||
params = {}
|
params = {}
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if isinstance(v, list) and len(v) == 1:
|
if isinstance(v, list) and len(v) == 1:
|
||||||
@ -429,12 +544,37 @@ class TextProcessingNode(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
params[k] = v
|
params[k] = v
|
||||||
|
|
||||||
result = cls._process(texts, **params)
|
if is_group:
|
||||||
return io.NodeOutput(result)
|
# Group processing: texts is list, call _group_process
|
||||||
|
result = cls._group_process(texts, **params)
|
||||||
|
else:
|
||||||
|
# Individual processing: texts is single item, call _process
|
||||||
|
result = cls._process(texts, **params)
|
||||||
|
|
||||||
|
# Wrap result based on is_output_list
|
||||||
|
if cls.is_output_list:
|
||||||
|
# Result should already be a list (or will be for individual)
|
||||||
|
return io.NodeOutput(result if is_group else [result])
|
||||||
|
else:
|
||||||
|
# Single output - wrap in list for NodeOutput
|
||||||
|
return io.NodeOutput([result])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts, **kwargs):
|
def _process(cls, text, **kwargs):
|
||||||
"""Override this method in subclasses to implement specific processing.
|
"""Override this method for single-item processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: str - Single text string
|
||||||
|
**kwargs: Additional parameters (already extracted from lists)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str - Processed text
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.__name__} must implement _process method")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _group_process(cls, texts, **kwargs):
|
||||||
|
"""Override this method for group processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: list[str] - List of text strings
|
texts: list[str] - List of text strings
|
||||||
@ -443,7 +583,9 @@ class TextProcessingNode(io.ComfyNode):
|
|||||||
Returns:
|
Returns:
|
||||||
list[str] - Processed texts
|
list[str] - Processed texts
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"{cls.__name__} must implement _process method")
|
raise NotImplementedError(
|
||||||
|
f"{cls.__name__} must implement _group_process method"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ========== Image Transform Nodes ==========
|
# ========== Image Transform Nodes ==========
|
||||||
@ -465,31 +607,28 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, width, height, mode):
|
def _process(cls, image, width, height, mode):
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
|
||||||
img = tensor_to_pil(img_tensor)
|
|
||||||
|
|
||||||
if mode == "stretch":
|
if mode == "stretch":
|
||||||
|
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||||
|
elif mode == "crop_center":
|
||||||
|
left = max(0, (img.width - width) // 2)
|
||||||
|
top = max(0, (img.height - height) // 2)
|
||||||
|
right = min(img.width, left + width)
|
||||||
|
bottom = min(img.height, top + height)
|
||||||
|
img = img.crop((left, top, right, bottom))
|
||||||
|
if img.width != width or img.height != height:
|
||||||
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||||
elif mode == "crop_center":
|
elif mode == "pad":
|
||||||
left = max(0, (img.width - width) // 2)
|
img.thumbnail((width, height), Image.Resampling.LANCZOS)
|
||||||
top = max(0, (img.height - height) // 2)
|
new_img = Image.new("RGB", (width, height), (0, 0, 0))
|
||||||
right = min(img.width, left + width)
|
paste_x = (width - img.width) // 2
|
||||||
bottom = min(img.height, top + height)
|
paste_y = (height - img.height) // 2
|
||||||
img = img.crop((left, top, right, bottom))
|
new_img.paste(img, (paste_x, paste_y))
|
||||||
if img.width != width or img.height != height:
|
img = new_img
|
||||||
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
|
||||||
elif mode == "pad":
|
|
||||||
img.thumbnail((width, height), Image.Resampling.LANCZOS)
|
|
||||||
new_img = Image.new("RGB", (width, height), (0, 0, 0))
|
|
||||||
paste_x = (width - img.width) // 2
|
|
||||||
paste_y = (height - img.height) // 2
|
|
||||||
new_img.paste(img, (paste_x, paste_y))
|
|
||||||
img = new_img
|
|
||||||
|
|
||||||
output_images.append(pil_to_tensor(img))
|
return pil_to_tensor(img)
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeImagesToPixelCountNode(ImageProcessingNode):
|
class ResizeImagesToPixelCountNode(ImageProcessingNode):
|
||||||
@ -514,18 +653,15 @@ class ResizeImagesToPixelCountNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, pixel_count, steps):
|
def _process(cls, image, pixel_count, steps):
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
w, h = img.size
|
||||||
img = tensor_to_pil(img_tensor)
|
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
|
||||||
w, h = img.size
|
new_w = int(w * pixel_count_ratio / steps) * steps
|
||||||
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
|
new_h = int(h * pixel_count_ratio / steps) * steps
|
||||||
new_w = int(h * pixel_count_ratio / steps) * steps
|
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
|
||||||
new_h = int(w * pixel_count_ratio / steps) * steps
|
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
|
return pil_to_tensor(img)
|
||||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
||||||
output_images.append(pil_to_tensor(img))
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||||
@ -543,20 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, shorter_edge):
|
def _process(cls, image, shorter_edge):
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
w, h = img.size
|
||||||
img = tensor_to_pil(img_tensor)
|
if w < h:
|
||||||
w, h = img.size
|
new_w = shorter_edge
|
||||||
if w < h:
|
new_h = int(h * (shorter_edge / w))
|
||||||
new_w = shorter_edge
|
else:
|
||||||
new_h = int(h * (shorter_edge / w))
|
new_h = shorter_edge
|
||||||
else:
|
new_w = int(w * (shorter_edge / h))
|
||||||
new_h = shorter_edge
|
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
new_w = int(w * (shorter_edge / h))
|
return pil_to_tensor(img)
|
||||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
||||||
output_images.append(pil_to_tensor(img))
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
||||||
@ -574,20 +707,17 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, longer_edge):
|
def _process(cls, image, longer_edge):
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
w, h = img.size
|
||||||
img = tensor_to_pil(img_tensor)
|
if w > h:
|
||||||
w, h = img.size
|
new_w = longer_edge
|
||||||
if w > h:
|
new_h = int(h * (longer_edge / w))
|
||||||
new_w = longer_edge
|
else:
|
||||||
new_h = int(h * (longer_edge / w))
|
new_h = longer_edge
|
||||||
else:
|
new_w = int(w * (longer_edge / h))
|
||||||
new_h = longer_edge
|
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
new_w = int(w * (longer_edge / h))
|
return pil_to_tensor(img)
|
||||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
||||||
output_images.append(pil_to_tensor(img))
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class CenterCropImagesNode(ImageProcessingNode):
|
class CenterCropImagesNode(ImageProcessingNode):
|
||||||
@ -600,17 +730,14 @@ class CenterCropImagesNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, width, height):
|
def _process(cls, image, width, height):
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
left = max(0, (img.width - width) // 2)
|
||||||
img = tensor_to_pil(img_tensor)
|
top = max(0, (img.height - height) // 2)
|
||||||
left = max(0, (img.width - width) // 2)
|
right = min(img.width, left + width)
|
||||||
top = max(0, (img.height - height) // 2)
|
bottom = min(img.height, top + height)
|
||||||
right = min(img.width, left + width)
|
img = img.crop((left, top, right, bottom))
|
||||||
bottom = min(img.height, top + height)
|
return pil_to_tensor(img)
|
||||||
img = img.crop((left, top, right, bottom))
|
|
||||||
output_images.append(pil_to_tensor(img))
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class RandomCropImagesNode(ImageProcessingNode):
|
class RandomCropImagesNode(ImageProcessingNode):
|
||||||
@ -628,20 +755,17 @@ class RandomCropImagesNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, width, height, seed):
|
def _process(cls, image, width, height, seed):
|
||||||
np.random.seed(seed % (2**32 - 1))
|
np.random.seed(seed % (2**32 - 1))
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
max_left = max(0, img.width - width)
|
||||||
img = tensor_to_pil(img_tensor)
|
max_top = max(0, img.height - height)
|
||||||
max_left = max(0, img.width - width)
|
left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
|
||||||
max_top = max(0, img.height - height)
|
top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
|
||||||
left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
|
right = min(img.width, left + width)
|
||||||
top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
|
bottom = min(img.height, top + height)
|
||||||
right = min(img.width, left + width)
|
img = img.crop((left, top, right, bottom))
|
||||||
bottom = min(img.height, top + height)
|
return pil_to_tensor(img)
|
||||||
img = img.crop((left, top, right, bottom))
|
|
||||||
output_images.append(pil_to_tensor(img))
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class FlipImagesNode(ImageProcessingNode):
|
class FlipImagesNode(ImageProcessingNode):
|
||||||
@ -658,16 +782,13 @@ class FlipImagesNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, direction):
|
def _process(cls, image, direction):
|
||||||
output_images = []
|
img = tensor_to_pil(image)
|
||||||
for img_tensor in images:
|
if direction == "horizontal":
|
||||||
img = tensor_to_pil(img_tensor)
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
if direction == "horizontal":
|
else:
|
||||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
else:
|
return pil_to_tensor(img)
|
||||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
|
||||||
output_images.append(pil_to_tensor(img))
|
|
||||||
return output_images
|
|
||||||
|
|
||||||
|
|
||||||
class NormalizeImagesNode(ImageProcessingNode):
|
class NormalizeImagesNode(ImageProcessingNode):
|
||||||
@ -692,8 +813,8 @@ class NormalizeImagesNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, mean, std):
|
def _process(cls, image, mean, std):
|
||||||
return [(img - mean) / std for img in images]
|
return (image - mean) / std
|
||||||
|
|
||||||
|
|
||||||
class AdjustBrightnessNode(ImageProcessingNode):
|
class AdjustBrightnessNode(ImageProcessingNode):
|
||||||
@ -711,8 +832,8 @@ class AdjustBrightnessNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, factor):
|
def _process(cls, image, factor):
|
||||||
return [(img * factor).clamp(0.0, 1.0) for img in images]
|
return (image * factor).clamp(0.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
class AdjustContrastNode(ImageProcessingNode):
|
class AdjustContrastNode(ImageProcessingNode):
|
||||||
@ -730,14 +851,15 @@ class AdjustContrastNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, factor):
|
def _process(cls, image, factor):
|
||||||
return [((img - 0.5) * factor + 0.5).clamp(0.0, 1.0) for img in images]
|
return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
class ShuffleDatasetNode(ImageProcessingNode):
|
class ShuffleDatasetNode(ImageProcessingNode):
|
||||||
node_id = "ShuffleDataset"
|
node_id = "ShuffleDataset"
|
||||||
display_name = "Shuffle Image Dataset"
|
display_name = "Shuffle Image Dataset"
|
||||||
description = "Randomly shuffle the order of images in the dataset."
|
description = "Randomly shuffle the order of images in the dataset."
|
||||||
|
is_group_process = True # Requires full list to shuffle
|
||||||
extra_inputs = [
|
extra_inputs = [
|
||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
|
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
|
||||||
@ -745,7 +867,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, images, seed):
|
def _group_process(cls, images, seed):
|
||||||
np.random.seed(seed % (2**32 - 1))
|
np.random.seed(seed % (2**32 - 1))
|
||||||
indices = np.random.permutation(len(images))
|
indices = np.random.permutation(len(images))
|
||||||
return [images[i] for i in indices]
|
return [images[i] for i in indices]
|
||||||
@ -804,8 +926,8 @@ class TextToLowercaseNode(TextProcessingNode):
|
|||||||
description = "Convert all texts to lowercase."
|
description = "Convert all texts to lowercase."
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts):
|
def _process(cls, text):
|
||||||
return [text.lower() for text in texts]
|
return text.lower()
|
||||||
|
|
||||||
|
|
||||||
class TextToUppercaseNode(TextProcessingNode):
|
class TextToUppercaseNode(TextProcessingNode):
|
||||||
@ -814,8 +936,8 @@ class TextToUppercaseNode(TextProcessingNode):
|
|||||||
description = "Convert all texts to uppercase."
|
description = "Convert all texts to uppercase."
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts):
|
def _process(cls, text):
|
||||||
return [text.upper() for text in texts]
|
return text.upper()
|
||||||
|
|
||||||
|
|
||||||
class TruncateTextNode(TextProcessingNode):
|
class TruncateTextNode(TextProcessingNode):
|
||||||
@ -829,8 +951,8 @@ class TruncateTextNode(TextProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts, max_length):
|
def _process(cls, text, max_length):
|
||||||
return [text[:max_length] for text in texts]
|
return text[:max_length]
|
||||||
|
|
||||||
|
|
||||||
class AddTextPrefixNode(TextProcessingNode):
|
class AddTextPrefixNode(TextProcessingNode):
|
||||||
@ -842,8 +964,8 @@ class AddTextPrefixNode(TextProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts, prefix):
|
def _process(cls, text, prefix):
|
||||||
return [prefix + text for text in texts]
|
return prefix + text
|
||||||
|
|
||||||
|
|
||||||
class AddTextSuffixNode(TextProcessingNode):
|
class AddTextSuffixNode(TextProcessingNode):
|
||||||
@ -855,8 +977,8 @@ class AddTextSuffixNode(TextProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts, suffix):
|
def _process(cls, text, suffix):
|
||||||
return [text + suffix for text in texts]
|
return text + suffix
|
||||||
|
|
||||||
|
|
||||||
class ReplaceTextNode(TextProcessingNode):
|
class ReplaceTextNode(TextProcessingNode):
|
||||||
@ -869,8 +991,8 @@ class ReplaceTextNode(TextProcessingNode):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts, find, replace):
|
def _process(cls, text, find, replace):
|
||||||
return [text.replace(find, replace) for text in texts]
|
return text.replace(find, replace)
|
||||||
|
|
||||||
|
|
||||||
class StripWhitespaceNode(TextProcessingNode):
|
class StripWhitespaceNode(TextProcessingNode):
|
||||||
@ -879,8 +1001,189 @@ class StripWhitespaceNode(TextProcessingNode):
|
|||||||
description = "Strip leading and trailing whitespace from all texts."
|
description = "Strip leading and trailing whitespace from all texts."
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, texts):
|
def _process(cls, text):
|
||||||
return [text.strip() for text in texts]
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Group Processing Example Nodes ==========
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDeduplicationNode(ImageProcessingNode):
|
||||||
|
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
||||||
|
|
||||||
|
node_id = "ImageDeduplication"
|
||||||
|
display_name = "Image Deduplication"
|
||||||
|
description = "Remove duplicate or very similar images from the dataset."
|
||||||
|
is_group_process = True # Requires full list to compare images
|
||||||
|
extra_inputs = [
|
||||||
|
io.Float.Input(
|
||||||
|
"similarity_threshold",
|
||||||
|
default=0.95,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _group_process(cls, images, similarity_threshold):
|
||||||
|
"""Remove duplicate images using perceptual hashing."""
|
||||||
|
if len(images) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Compute simple perceptual hash for each image
|
||||||
|
def compute_hash(img_tensor):
|
||||||
|
"""Compute a simple perceptual hash by resizing to 8x8 and comparing to average."""
|
||||||
|
img = tensor_to_pil(img_tensor)
|
||||||
|
# Resize to 8x8
|
||||||
|
img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L")
|
||||||
|
# Get pixels
|
||||||
|
pixels = list(img_small.getdata())
|
||||||
|
# Compute average
|
||||||
|
avg = sum(pixels) / len(pixels)
|
||||||
|
# Create hash (1 if above average, 0 otherwise)
|
||||||
|
hash_bits = "".join("1" if p > avg else "0" for p in pixels)
|
||||||
|
return hash_bits
|
||||||
|
|
||||||
|
def hamming_distance(hash1, hash2):
|
||||||
|
"""Compute Hamming distance between two hash strings."""
|
||||||
|
return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
|
||||||
|
|
||||||
|
# Compute hashes for all images
|
||||||
|
hashes = [compute_hash(img) for img in images]
|
||||||
|
|
||||||
|
# Find duplicates
|
||||||
|
keep_indices = []
|
||||||
|
for i in range(len(images)):
|
||||||
|
is_duplicate = False
|
||||||
|
for j in keep_indices:
|
||||||
|
# Compare hashes
|
||||||
|
distance = hamming_distance(hashes[i], hashes[j])
|
||||||
|
similarity = 1.0 - (distance / 64.0) # 64 bits total
|
||||||
|
if similarity >= similarity_threshold:
|
||||||
|
is_duplicate = True
|
||||||
|
logging.info(
|
||||||
|
f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not is_duplicate:
|
||||||
|
keep_indices.append(i)
|
||||||
|
|
||||||
|
# Return only unique images
|
||||||
|
unique_images = [images[i] for i in keep_indices]
|
||||||
|
logging.info(
|
||||||
|
f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
|
||||||
|
)
|
||||||
|
return unique_images
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGridNode(ImageProcessingNode):
    """Arrange a list of images into one composite grid/collage image.

    Every input image is resized to a fixed cell size and pasted onto a
    black canvas, filling left-to-right then top-to-bottom. The whole
    input list is consumed at once and a single image is produced.
    """

    node_id = "ImageGrid"
    display_name = "Image Grid"
    description = "Arrange multiple images into a grid layout."
    is_group_process = True  # layout needs the entire list at once
    is_output_list = False  # a single composite image comes out
    extra_inputs = [
        io.Int.Input(
            "columns",
            default=4,
            min=1,
            max=20,
            tooltip="Number of columns in the grid.",
        ),
        io.Int.Input(
            "cell_width",
            default=256,
            min=32,
            max=2048,
            tooltip="Width of each cell in the grid.",
        ),
        io.Int.Input(
            "cell_height",
            default=256,
            min=32,
            max=2048,
            tooltip="Height of each cell in the grid.",
        ),
        io.Int.Input(
            "padding", default=4, min=0, max=50, tooltip="Padding between images."
        ),
    ]

    @classmethod
    def _group_process(cls, images, columns, cell_width, cell_height, padding):
        """Lay the images out on a grid and return the composite as a tensor.

        Args:
            images: list of image tensors to arrange.
            columns: grid column count; row count is derived from it.
            cell_width: pixel width every image is resized to.
            cell_height: pixel height every image is resized to.
            padding: pixel gap between neighbouring cells.

        Raises:
            ValueError: if ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("Cannot create grid from empty image list")

        num_images = len(images)
        rows = -(-num_images // columns)  # ceil(num_images / columns)

        # Canvas size: all cells plus the gaps between them (no outer border).
        grid_width = columns * cell_width + (columns - 1) * padding
        grid_height = rows * cell_height + (rows - 1) * padding
        grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))

        for idx, img_tensor in enumerate(images):
            row, col = divmod(idx, columns)
            # Force the image to the cell size, then paste at its slot.
            cell = tensor_to_pil(img_tensor).resize(
                (cell_width, cell_height), Image.Resampling.LANCZOS
            )
            grid.paste(
                cell,
                (col * (cell_width + padding), row * (cell_height + padding)),
            )

        logging.info(
            f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
        )
        return pil_to_tensor(grid)
|
||||||
|
|
||||||
|
|
||||||
|
class MergeImageListsNode(ImageProcessingNode):
    """Concatenate several image lists into one flat list."""

    node_id = "MergeImageLists"
    display_name = "Merge Image Lists"
    description = "Concatenate multiple image lists into one."
    is_group_process = True  # the node sees all images gathered into one list

    @classmethod
    def _group_process(cls, images):
        """Return the image list unchanged.

        Upstream input handling already concatenates multiple connected
        list inputs, so this node is a simple pass-through.
        """
        count = len(images)
        logging.info(f"Merged image list contains {count} images")
        return images
|
||||||
|
|
||||||
|
|
||||||
|
class MergeTextListsNode(TextProcessingNode):
    """Concatenate several text lists into one flat list."""

    node_id = "MergeTextLists"
    display_name = "Merge Text Lists"
    description = "Concatenate multiple text lists into one."
    is_group_process = True  # the node sees all texts gathered into one list

    @classmethod
    def _group_process(cls, texts):
        """Return the text list unchanged.

        Upstream input handling already concatenates multiple connected
        list inputs, so this node is a simple pass-through.
        """
        count = len(texts)
        logging.info(f"Merged text list contains {count} texts")
        return texts
|
||||||
|
|
||||||
|
|
||||||
# ========== Training Dataset Nodes ==========
|
# ========== Training Dataset Nodes ==========
|
||||||
@ -1182,6 +1485,11 @@ class DatasetExtension(ComfyExtension):
|
|||||||
AddTextSuffixNode,
|
AddTextSuffixNode,
|
||||||
ReplaceTextNode,
|
ReplaceTextNode,
|
||||||
StripWhitespaceNode,
|
StripWhitespaceNode,
|
||||||
|
# Group processing examples
|
||||||
|
ImageDeduplicationNode,
|
||||||
|
ImageGridNode,
|
||||||
|
MergeImageListsNode,
|
||||||
|
MergeTextListsNode,
|
||||||
# Training dataset nodes
|
# Training dataset nodes
|
||||||
MakeTrainingDataset,
|
MakeTrainingDataset,
|
||||||
SaveTrainingDataset,
|
SaveTrainingDataset,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user