ComfyUI/comfy_extras/nodes_dataset.py

2034 lines
74 KiB
Python

import logging
import os
import json
import pickle
import struct
import numpy as np
import safetensors.torch
import torch
from PIL import Image
from safetensors import safe_open
from typing_extensions import override
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io
def load_and_process_images(image_files, input_dir):
    """Load a list of image files as RGB float tensors.

    Args:
        image_files: List of image paths. Each entry is joined onto
            ``input_dir`` with ``os.path.join`` (an absolute entry therefore
            replaces ``input_dir``).
        input_dir: Base directory containing the images.

    Returns:
        list[torch.Tensor]: One float32 tensor per file, shaped [1, H, W, 3],
        with pixel values divided by 255. Images are not resized, so entries
        may differ in spatial size.

    Raises:
        ValueError: If ``image_files`` is empty.
    """
    if not image_files:
        raise ValueError("No valid images found in input")
    output_images = []
    for file in image_files:
        image_path = os.path.join(input_dir, file)
        # Close the source file as soon as the pixels have been copied out;
        # convert() always returns a new, fully loaded image, so the result
        # stays valid after the handle is released.
        with node_helpers.pillow(Image.open, image_path) as source:
            img = source
            if img.mode == "I":
                # 32-bit integer pixels are scaled by 1/255 before RGB conversion.
                img = img.point(lambda i: i * (1 / 255))
            img = img.convert("RGB")
        img_array = np.array(img).astype(np.float32) / 255.0
        output_images.append(torch.from_numpy(img_array)[None,])
    return output_images
class LoadImageDataSetFromFolderNode(io.ComfyNode):
    """Node that loads every supported image found directly inside an input subfolder."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageDataSetFromFolder",
            search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
            display_name="Load Image (from Folder)",
            category="image",
            description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder",
                    options=folder_paths.get_input_subfolders(),
                    tooltip="The folder to load images from.",
                )
            ],
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=True,
                    tooltip="List of loaded images",
                )
            ],
        )

    @classmethod
    def execute(cls, folder):
        """Load all PNG/JPG/JPEG/WEBP files directly inside ``folder`` (non-recursive).

        Files are loaded in sorted filename order so the output list is
        reproducible (``os.listdir`` returns entries in arbitrary order).
        Directories whose names happen to end in an image extension are skipped.
        """
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = (".png", ".jpg", ".jpeg", ".webp")
        image_files = sorted(
            f
            for f in os.listdir(sub_input_dir)
            if f.lower().endswith(valid_extensions)
            and os.path.isfile(os.path.join(sub_input_dir, f))
        )
        output_tensor = load_and_process_images(image_files, sub_input_dir)
        return io.NodeOutput(output_tensor)
class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
    """Node that loads image/caption pairs from an input subfolder.

    Each image's caption is read from a ``.txt`` file with the same path stem;
    images without such a file get an empty caption.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadImageTextDataSetFromFolder",
            search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
            display_name="Load Image-Text (from Folder)",
            category="image",
            description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder",
                    options=folder_paths.get_input_subfolders(),
                    tooltip="The folder to load images and text captions from.",
                )
            ],
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=True,
                    tooltip="List of loaded images",
                ),
                io.String.Output(
                    display_name="texts",
                    is_output_list=True,
                    tooltip="List of text captions",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder):
        """Load images and their captions from ``folder``.

        Images directly inside the folder are loaded once. Subfolders follow
        the kohya-ss/sd-scripts layout: a subfolder named ``<N>_<name>`` has
        its images repeated ``N`` times, any other subfolder is loaded once.
        Entries are visited in sorted order so the output is reproducible.
        """
        logging.info(f"Loading images from folder: {folder}")
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = (".png", ".jpg", ".jpeg", ".webp")

        def is_image_file(directory, name):
            # Match by extension, but skip directories that merely look like images.
            return name.lower().endswith(valid_extensions) and os.path.isfile(
                os.path.join(directory, name)
            )

        # Paths are kept relative to sub_input_dir so that the image loader and
        # the caption lookup both join them against the same base exactly once.
        image_files = []
        for item in sorted(os.listdir(sub_input_dir)):
            path = os.path.join(sub_input_dir, item)
            if os.path.isdir(path):
                # Support kohya-ss/sd-scripts folder structure ("<repeat>_<name>").
                # isdecimal() (not isdigit()) guarantees int() accepts the token.
                repeat_token = item.split("_")[0]
                repeat = int(repeat_token) if repeat_token.isdecimal() else 1
                sub_images = [
                    os.path.join(item, f)
                    for f in sorted(os.listdir(path))
                    if is_image_file(path, f)
                ]
                image_files.extend(sub_images * repeat)
            elif is_image_file(sub_input_dir, item):
                image_files.append(item)

        captions = []
        for image_file in image_files:
            # Swap only the final extension; str.replace would also rewrite any
            # earlier occurrence of the extension text elsewhere in the path.
            caption_path = os.path.join(
                sub_input_dir, os.path.splitext(image_file)[0] + ".txt"
            )
            if os.path.isfile(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    captions.append(f.read().strip())
            else:
                captions.append("")

        output_tensor = load_and_process_images(image_files, sub_input_dir)
        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
        return io.NodeOutput(output_tensor, captions)
def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True):
    """Save a list of image tensors to ``output_dir`` as PNG files.

    Args:
        image_list: List of image tensors, each [1, H, W, C], [H, W, C] or
            [C, H, W]. Channel-first layout is assumed when the first dim is
            1, 3 or 4 and both remaining dims are larger than 4. Values are
            multiplied by 255, clipped to [0, 255] and cast to uint8.
        output_dir: Directory to save images to (created if missing).
        prefix: Filename prefix.
        overwrite: If True, files are named ``{prefix}_{idx:05d}.png`` and
            replace any existing files of the same name. If False, one counter
            is obtained from ``folder_paths.get_save_image_path`` for the whole
            call and files are named ``{prefix}_{counter:05}_{idx:05d}.png``.

    Returns:
        List of saved filenames (relative to ``output_dir``).

    Raises:
        ValueError: If an element of ``image_list`` is not a torch.Tensor.
    """
    os.makedirs(output_dir, exist_ok=True)
    if overwrite:
        name_stem = prefix
    else:
        # Resolve the counter once so every image of this call shares it.
        # Calling get_save_image_path per image rescans the folder, finds the
        # file just written and advances the counter for every single image.
        _, _, counter, _, resolved_prefix = folder_paths.get_save_image_path(prefix, output_dir)
        name_stem = f"{resolved_prefix}_{counter:05}"

    saved_files = []
    for idx, img_tensor in enumerate(image_list):
        if not isinstance(img_tensor, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
        # Remove batch dimension if present [1, H, W, C] -> [H, W, C]
        if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
            img_tensor = img_tensor.squeeze(0)
        # If tensor is [C, H, W], permute to [H, W, C]
        if (
            img_tensor.dim() == 3
            and img_tensor.shape[0] in (1, 3, 4)
            and img_tensor.shape[1] > 4
            and img_tensor.shape[2] > 4
        ):
            img_tensor = img_tensor.permute(1, 2, 0)
        # PIL cannot build an image from an HxWx1 array; drop the channel so
        # single-channel data is saved as grayscale.
        if img_tensor.dim() == 3 and img_tensor.shape[-1] == 1:
            img_tensor = img_tensor.squeeze(-1)
        # numpy has no bfloat16; detach() also allows tensors that require grad.
        if img_tensor.dtype == torch.bfloat16:
            img_tensor = img_tensor.float()
        img_array = img_tensor.detach().cpu().numpy()
        img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)
        img = Image.fromarray(img_array)

        filename = f"{name_stem}_{idx:05d}.png"
        img.save(os.path.join(output_dir, filename))
        saved_files.append(filename)
    return saved_files
class SaveImageDataSetToFolderNode(io.ComfyNode):
    """Deprecated node that writes a list of images as PNGs into a subfolder of the output directory."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveImageDataSetToFolder",
            search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
            display_name="Save Image (to Folder) (DEPRECATED)",
            category="image",
            description="Save a dataset of images to a specified folder. Supported formats: PNG.",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,  # Receive images as list
            inputs=[
                io.Image.Input("images", tooltip="List of images to save."),
                io.String.Input(
                    "folder_name",
                    default="dataset",
                    tooltip="Name of the folder to save images to (inside output directory).",
                ),
                io.String.Input(
                    "filename_prefix",
                    default="image",
                    tooltip="Prefix for saved image filenames.",
                    advanced=True,
                ),
                io.Combo.Input(
                    "mode",
                    default="overwrite",
                    options=["overwrite", "increment"],
                    tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
                ),
            ],
            outputs=[],
            is_deprecated=True,  # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
        )

    @classmethod
    def execute(cls, images, folder_name, filename_prefix, mode):
        """Save ``images`` into ``<output dir>/<folder_name>``.

        Raises:
            ValueError: If ``folder_name`` resolves outside the output directory.
        """
        # is_input_list=True delivers every input as a list; widgets hold one value.
        folder_name = folder_name[0]
        filename_prefix = filename_prefix[0]
        mode = mode[0]
        output_root = os.path.abspath(folder_paths.get_output_directory())
        output_dir = os.path.abspath(os.path.join(output_root, folder_name))
        # folder_name comes from the workflow; reject "../x" or absolute paths
        # that would make the node write outside the output directory.
        if os.path.commonpath((output_root, output_dir)) != output_root:
            raise ValueError(
                f"folder_name must resolve to a path inside the output directory, got: {folder_name!r}"
            )
        saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode == 'overwrite')
        logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
        return io.NodeOutput()
class SaveImageTextDataSetToFolderNode(io.ComfyNode):
    """Node that writes image/caption pairs (PNG + same-stem TXT) into a subfolder of the output directory."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveImageTextDataSetToFolder",
            search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
            display_name="Save Image-Text (to Folder)",
            category="image",
            description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,  # Receive both images and texts as lists
            inputs=[
                io.Image.Input("images", tooltip="List of images to save."),
                io.String.Input("texts",
                    optional=True,
                    force_input=True,
                    tooltip="List of text captions to save."
                ),
                io.String.Input(
                    "folder_name",
                    default="dataset",
                    tooltip="Name of the folder to save images to (inside output directory).",
                ),
                io.String.Input(
                    "filename_prefix",
                    default="image",
                    tooltip="Prefix for saved image filenames.",
                    advanced=True,
                ),
                io.Combo.Input(
                    "mode",
                    default="overwrite",
                    options=["overwrite", "increment"],
                    tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, images, folder_name, filename_prefix, mode, texts=None):
        """Save ``images`` (and, if given, ``texts``) into ``<output dir>/<folder_name>``.

        Each caption is written to a ``.txt`` file sharing its image's stem.
        Images and captions are paired by position.

        Raises:
            ValueError: If ``folder_name`` resolves outside the output directory.
        """
        # is_input_list=True delivers every input as a list; widgets hold one value.
        folder_name = folder_name[0]
        filename_prefix = filename_prefix[0]
        mode = mode[0]
        output_root = os.path.abspath(folder_paths.get_output_directory())
        output_dir = os.path.abspath(os.path.join(output_root, folder_name))
        # folder_name comes from the workflow; reject "../x" or absolute paths
        # that would make the node write outside the output directory.
        if os.path.commonpath((output_root, output_dir)) != output_root:
            raise ValueError(
                f"folder_name must resolve to a path inside the output directory, got: {folder_name!r}"
            )
        saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode == 'overwrite')

        if not texts:
            logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
            return io.NodeOutput()

        if len(texts) != len(saved_files):
            logging.warning(
                f"Got {len(saved_files)} images but {len(texts)} captions; "
                f"only the first {min(len(saved_files), len(texts))} pairs get a caption file."
            )
        for filename, caption in zip(saved_files, texts):
            # Replace only the final ".png"; a prefix may itself contain ".png".
            caption_path = os.path.join(output_dir, os.path.splitext(filename)[0] + ".txt")
            with open(caption_path, "w", encoding="utf-8") as f:
                f.write(caption)
        logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
        return io.NodeOutput()
# ========== Helper Functions for Transform Nodes ==========
def tensor_to_pil(img_tensor):
    """Convert an image tensor to a PIL Image.

    Accepts [1, H, W, C] (the batch dim is dropped) or [H, W, C]. Values are
    multiplied by 255, clipped to [0, 255] and cast to uint8. A trailing
    single channel is dropped so [H, W, 1] data becomes a grayscale image
    (PIL cannot build an image from an HxWx1 array).
    """
    if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
        img_tensor = img_tensor.squeeze(0)
    if img_tensor.dim() == 3 and img_tensor.shape[-1] == 1:
        img_tensor = img_tensor.squeeze(-1)
    # numpy has no bfloat16; detach() also allows tensors that require grad.
    if img_tensor.dtype == torch.bfloat16:
        img_tensor = img_tensor.float()
    img_array = (img_tensor.detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(img_array)
def pil_to_tensor(img):
    """Return ``img`` as a float32 tensor divided by 255, with a leading batch axis of size 1."""
    pixels = np.array(img).astype(np.float32)
    scaled = pixels / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
# ========== Base Classes for Transform Nodes ==========
class ImageProcessingNode(io.ComfyNode):
    """Base class for image processing nodes that operate on images.

    Child classes should set:
        node_id: Unique node identifier (required)
        category: Node category (optional, defaults to "image")
        search_aliases: List of search aliases (optional)
        display_name: Display name (optional, defaults to node_id)
        description: Node description (optional)
        extra_inputs: List of additional io.Input objects beyond "images" (optional)
        is_group_process: None (auto-detect), True (group), or False (individual) (optional)
        is_output_list: None (auto-detect: True for group processing, False for
            individual processing), True (list output) or False (single output) (optional)
        is_deprecated: True if the node is deprecated (optional, default False)

    Child classes must implement ONE of:
        _process(cls, image, **kwargs) -> tensor (for single-item processing)
        _group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
    """

    node_id = None
    category = "image"
    search_aliases = []
    display_name = None
    description = None
    extra_inputs = []
    is_group_process = None  # None = auto-detect, True/False = explicit
    is_output_list = None  # None = auto-detect based on processing mode
    is_deprecated = False

    @classmethod
    def _detect_processing_mode(cls):
        """Detect whether this node uses group or individual processing.

        Returns:
            bool: True if group processing, False if individual processing

        Raises:
            ValueError: If both or neither of _process/_group_process are overridden.
        """
        # Explicit setting takes precedence
        if cls.is_group_process is not None:
            return cls.is_group_process

        def defining_class(attr):
            # First class in the MRO whose own namespace defines ``attr``.
            return next((klass for klass in cls.__mro__ if attr in klass.__dict__), None)

        process_definer = defining_class("_process")
        group_definer = defining_class("_group_process")

        # Check what was overridden (not defined in base class)
        has_process = process_definer is not None and process_definer is not ImageProcessingNode
        has_group = group_definer is not None and group_definer is not ImageProcessingNode

        if has_process and has_group:
            raise ValueError(
                f"{cls.__name__}: Cannot override both _process and _group_process. "
                "Override only one, or set is_group_process explicitly."
            )
        if not has_process and not has_group:
            raise ValueError(
                f"{cls.__name__}: Must override either _process or _group_process"
            )
        return has_group

    @classmethod
    def _ensure_image_list(cls, images):
        """Normalize to a flat list of [1, H, W, C] tensors."""
        if isinstance(images, torch.Tensor):
            if images.ndim != 4:
                raise ValueError(f"Expected 4D image tensor, got shape {tuple(images.shape)}")
            return [images[i:i+1] for i in range(images.shape[0])]
        flat = []
        for item in images:
            if not isinstance(item, torch.Tensor) or item.ndim != 4:
                raise ValueError(f"Expected 4D image tensor, got {type(item).__name__} shape {getattr(item, 'shape', None)}")
            flat.extend([item[i:i+1] for i in range(item.shape[0])])
        return flat

    @classmethod
    def define_schema(cls):
        if cls.node_id is None:
            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

        is_group = cls._detect_processing_mode()

        # Auto-detect is_output_list if not explicitly set
        # Single processing: False (backend collects results into list)
        # Group processing: True by default (can be False for single-output nodes)
        output_is_list = (
            cls.is_output_list if cls.is_output_list is not None else is_group
        )

        inputs = [
            io.Image.Input(
                "images",
                tooltip=(
                    "List of images to process." if is_group else "Image to process."
                ),
            )
        ]
        inputs.extend(cls.extra_inputs)

        return io.Schema(
            node_id=cls.node_id,
            search_aliases=list(cls.search_aliases),
            display_name=cls.display_name or cls.node_id,
            category=cls.category,
            description=cls.description,
            is_experimental=True,
            # Previously declared on subclasses but never forwarded to the schema.
            is_deprecated=cls.is_deprecated,
            is_input_list=is_group,  # True for group, False for individual
            inputs=inputs,
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=output_is_list,
                    tooltip="Processed images",
                )
            ],
        )

    @classmethod
    def execute(cls, images, **kwargs):
        """Execute the node. Routes to _process or _group_process based on mode."""
        is_group = cls._detect_processing_mode()
        if is_group:
            images = cls._ensure_image_list(images)

        # Extract scalar values from lists for parameters
        params = {}
        for k, v in kwargs.items():
            if isinstance(v, list) and len(v) == 1:
                params[k] = v[0]
            else:
                params[k] = v

        if is_group:
            # Group processing: images is list, call _group_process
            result = cls._group_process(images, **params)
        else:
            # Individual processing: images is single item, call _process
            result = cls._process(images, **params)

        return io.NodeOutput(result)

    @classmethod
    def _process(cls, image, **kwargs):
        """Override this method for single-item processing.

        Args:
            image: tensor - Single image tensor
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            tensor - Processed image
        """
        raise NotImplementedError(f"{cls.__name__} must implement _process method")

    @classmethod
    def _group_process(cls, images, **kwargs):
        """Override this method for group processing.

        Args:
            images: list[tensor] - List of image tensors
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            list[tensor] - Processed images
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement _group_process method"
        )
class TextProcessingNode(io.ComfyNode):
    """Base class for text processing nodes that operate on texts.

    Child classes should set:
        node_id: Unique node identifier (required)
        category: Node category (optional, defaults to "text")
        search_aliases: List of search aliases (optional)
        display_name: Display name (optional, defaults to node_id)
        description: Node description (optional)
        extra_inputs: List of additional io.Input objects beyond "texts" (optional)
        is_group_process: None (auto-detect), True (group), or False (individual) (optional)
        is_output_list: None (auto-detect: True for group processing, False for
            individual processing), True (list output) or False (single output) (optional)
        is_deprecated: True if the node is deprecated (optional, default False)

    Child classes must implement ONE of:
        _process(cls, text, **kwargs) -> str (for single-item processing)
        _group_process(cls, texts, **kwargs) -> list[str] (for group processing)
    """

    node_id = None
    category = "text"
    search_aliases = []
    display_name = None
    description = None
    extra_inputs = []
    is_group_process = None  # None = auto-detect, True/False = explicit
    is_output_list = None  # None = auto-detect based on processing mode
    is_deprecated = False

    @classmethod
    def _detect_processing_mode(cls):
        """Detect whether this node uses group or individual processing.

        Returns:
            bool: True if group processing, False if individual processing

        Raises:
            ValueError: If both or neither of _process/_group_process are overridden.
        """
        # Explicit setting takes precedence
        if cls.is_group_process is not None:
            return cls.is_group_process

        def defining_class(attr):
            # First class in the MRO whose own namespace defines ``attr``.
            return next((klass for klass in cls.__mro__ if attr in klass.__dict__), None)

        process_definer = defining_class("_process")
        group_definer = defining_class("_group_process")

        # Check what was overridden (not defined in base class)
        has_process = process_definer is not None and process_definer is not TextProcessingNode
        has_group = group_definer is not None and group_definer is not TextProcessingNode

        if has_process and has_group:
            raise ValueError(
                f"{cls.__name__}: Cannot override both _process and _group_process. "
                "Override only one, or set is_group_process explicitly."
            )
        if not has_process and not has_group:
            raise ValueError(
                f"{cls.__name__}: Must override either _process or _group_process"
            )
        return has_group

    @classmethod
    def _resolve_output_is_list(cls, is_group):
        """Return the effective output-list flag: explicit setting, else True only for group processing."""
        return cls.is_output_list if cls.is_output_list is not None else is_group

    @classmethod
    def define_schema(cls):
        if cls.node_id is None:
            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

        is_group = cls._detect_processing_mode()

        inputs = [
            io.String.Input(
                "texts",
                tooltip="List of texts to process." if is_group else "Text to process.",
            )
        ]
        inputs.extend(cls.extra_inputs)

        return io.Schema(
            node_id=cls.node_id,
            search_aliases=list(cls.search_aliases),
            display_name=cls.display_name or cls.node_id,
            category=cls.category,
            description=cls.description,
            is_experimental=True,
            is_deprecated=cls.is_deprecated,
            is_input_list=is_group,  # True for group, False for individual
            inputs=inputs,
            outputs=[
                io.String.Output(
                    display_name="texts",
                    is_output_list=cls._resolve_output_is_list(is_group),
                    tooltip="Processed texts",
                )
            ],
        )

    @classmethod
    def execute(cls, texts, **kwargs):
        """Execute the node. Routes to _process or _group_process based on mode."""
        is_group = cls._detect_processing_mode()

        # Extract scalar values from lists for parameters
        params = {}
        for k, v in kwargs.items():
            if isinstance(v, list) and len(v) == 1:
                params[k] = v[0]
            else:
                params[k] = v

        if is_group:
            # Group processing: texts is list, call _group_process
            result = cls._group_process(texts, **params)
        else:
            # Individual processing: texts is single item, call _process
            result = cls._process(texts, **params)

        if cls._resolve_output_is_list(is_group):
            # List output: group results are already lists; a single processed
            # text becomes a one-item list that the backend concatenates.
            return io.NodeOutput(result if is_group else [result])
        # Single output: return the value itself (the backend collects per-call
        # results). Wrapping it in a list here would hand downstream nodes a
        # list instead of a string.
        return io.NodeOutput(result)

    @classmethod
    def _process(cls, text, **kwargs):
        """Override this method for single-item processing.

        Args:
            text: str - Single text string
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            str - Processed text
        """
        raise NotImplementedError(f"{cls.__name__} must implement _process method")

    @classmethod
    def _group_process(cls, texts, **kwargs):
        """Override this method for group processing.

        Args:
            texts: list[str] - List of text strings
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            list[str] - Processed texts
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement _group_process method"
        )
# ========== Image Transform Nodes ==========
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
    """Deprecated: scale each image so its shorter edge equals ``shorter_edge``, keeping aspect ratio."""

    node_id = "ResizeImagesByShorterEdge"
    display_name = "Resize Images by Shorter Edge (DEPRECATED)"
    category = "image/transform"
    description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
    is_deprecated = True  # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
    extra_inputs = [
        io.Int.Input(
            "shorter_edge",
            default=512,
            min=1,
            max=8192,
            tooltip="Target dimension for the shorter edge.",
        ),
    ]

    @classmethod
    def _process(cls, image, shorter_edge):
        """Resize every frame of ``image`` (LANCZOS) and return them concatenated on dim 0."""
        if image.dim() == 3:
            # A lone [H, W, C] frame; treat it as a batch of one.
            image = image.unsqueeze(0)
        resized_images = []
        for frame in image:
            img = tensor_to_pil(frame)
            w, h = img.size
            # Exact integer floor division; int(h * (e / w)) could lose a pixel
            # to floating-point error (e.g. 1023.999... -> 1023).
            if w < h:
                new_w, new_h = shorter_edge, h * shorter_edge // w
            else:
                new_w, new_h = w * shorter_edge // h, shorter_edge
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            resized_images.append(pil_to_tensor(img))
        return torch.cat(resized_images, dim=0)
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
    """Deprecated: scale each image so its longer edge equals ``longer_edge``, keeping aspect ratio."""

    node_id = "ResizeImagesByLongerEdge"
    display_name = "Resize Images by Longer Edge (DEPRECATED)"
    category = "image/transform"
    description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
    is_deprecated = True  # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
    extra_inputs = [
        io.Int.Input(
            "longer_edge",
            default=1024,
            min=1,
            max=8192,
            tooltip="Target dimension for the longer edge.",
        ),
    ]

    @classmethod
    def _process(cls, image, longer_edge):
        """Resize every frame of ``image`` (LANCZOS) and return them concatenated on dim 0."""
        if image.dim() == 3:
            # A lone [H, W, C] frame; iterating it directly would yield rows.
            image = image.unsqueeze(0)
        resized_images = []
        for image_i in image:
            img = tensor_to_pil(image_i)
            w, h = img.size
            # Exact integer floor division avoids float truncation errors;
            # max(1, ...) keeps extreme aspect ratios from producing a 0-pixel side.
            if w > h:
                new_w, new_h = longer_edge, max(1, h * longer_edge // w)
            else:
                new_w, new_h = max(1, w * longer_edge // h), longer_edge
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            resized_images.append(pil_to_tensor(img))
        return torch.cat(resized_images, dim=0)
class CenterCropImagesNode(ImageProcessingNode):
    """Crop the central ``width`` x ``height`` region of each image."""

    node_id = "CenterCropImages"
    search_aliases=["crop", "cut", "trim"]
    display_name="Crop Image (Center)"
    category="image/transform"
    description = "Center crop an image to the specified dimensions."
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
    ]

    @classmethod
    def _process(cls, image, width, height):
        """Return the centered crop of ``image`` ([..., H, W, C]).

        Slicing the tensor directly keeps full precision (no 8-bit PIL
        round-trip) and crops every frame of a batch at once. If the image is
        smaller than the requested size along an axis, that axis is kept whole.
        """
        img_h, img_w = image.shape[-3], image.shape[-2]
        left = max(0, (img_w - width) // 2)
        top = max(0, (img_h - height) // 2)
        right = min(img_w, left + width)
        bottom = min(img_h, top + height)
        # clone() so the output does not alias the input tensor's storage.
        return image[..., top:bottom, left:right, :].clone()
class RandomCropImagesNode(ImageProcessingNode):
    """Crop a seeded random ``width`` x ``height`` region of each image."""

    node_id = "RandomCropImages"
    search_aliases=["crop", "cut", "trim"]
    display_name = "Crop Image (Random)"
    category="image/transform"
    description = "Randomly crop an image to the specified dimensions."
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def _process(cls, image, width, height, seed):
        """Return a random crop of ``image`` ([..., H, W, C]).

        A private RandomState seeded like the old global call yields the same
        offsets without mutating numpy's global RNG. All frames of a batch
        share one offset. Slicing keeps full precision (no 8-bit PIL round-trip).
        """
        rng = np.random.RandomState(seed % (2**32 - 1))
        img_h, img_w = image.shape[-3], image.shape[-2]
        max_left = max(0, img_w - width)
        max_top = max(0, img_h - height)
        left = int(rng.randint(0, max_left + 1)) if max_left > 0 else 0
        top = int(rng.randint(0, max_top + 1)) if max_top > 0 else 0
        right = min(img_w, left + width)
        bottom = min(img_h, top + height)
        # clone() so the output does not alias the input tensor's storage.
        return image[..., top:bottom, left:right, :].clone()
class NormalizeImagesNode(ImageProcessingNode):
    """Apply ``(x - mean) / std`` element-wise to each image."""

    node_id = "NormalizeImages"
    search_aliases = ["normalize", "normalize colors"]
    display_name = "Normalize Image Colors"
    category = "image/color"
    description = "Normalize images using mean and standard deviation."
    extra_inputs = [
        io.Float.Input(
            "mean",
            default=0.5,
            min=0.0,
            max=1.0,
            tooltip="Mean value for normalization.",
            advanced=True,
        ),
        io.Float.Input(
            "std",
            default=0.5,
            min=0.001,
            max=1.0,
            tooltip="Standard deviation for normalization.",
            advanced=True,
        ),
    ]

    @classmethod
    def _process(cls, image, mean, std):
        # The result is not clamped, so values may leave the [0, 1] range.
        centered = image - mean
        return centered / std
class AdjustBrightnessNode(ImageProcessingNode):
    """Multiply pixel values by ``factor`` and clamp the result to [0, 1]."""

    node_id = "AdjustBrightness"
    search_aliases = ["brightness"]
    display_name = "Adjust Brightness"
    category = "image/adjustments"
    description = "Adjust the brightness of an image."
    extra_inputs = [
        io.Float.Input(
            "factor",
            default=1.0,
            min=0.0,
            max=2.0,
            tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
        ),
    ]

    @classmethod
    def _process(cls, image, factor):
        scaled = image * factor
        return scaled.clamp(min=0.0, max=1.0)
class AdjustContrastNode(ImageProcessingNode):
    """Scale pixel distance from mid-gray (0.5) by ``factor`` and clamp to [0, 1]."""

    node_id = "AdjustContrast"
    search_aliases = ["contrast"]
    display_name = "Adjust Contrast"
    category = "image/adjustments"
    description = "Adjust the contrast of an image."
    extra_inputs = [
        io.Float.Input(
            "factor",
            default=1.0,
            min=0.0,
            max=2.0,
            tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
        ),
    ]

    @classmethod
    def _process(cls, image, factor):
        offset_from_gray = image - 0.5
        stretched = offset_from_gray * factor + 0.5
        return stretched.clamp(min=0.0, max=1.0)
class ShuffleDatasetNode(ImageProcessingNode):
    """Return the image list in a seeded random order."""

    node_id = "ShuffleDataset"
    search_aliases=["shuffle", "randomize", "mix"]
    display_name = "Shuffle Images List"
    category = "image/batch"
    description = "Randomly shuffle the order of images in a list."
    is_group_process = True  # Requires full list to shuffle
    extra_inputs = [
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def _group_process(cls, images, seed):
        # A private RandomState seeded the same way produces the same
        # permutation as np.random.seed + np.random.permutation, without
        # mutating numpy's global RNG for every other consumer in the process.
        rng = np.random.RandomState(seed % (2**32 - 1))
        indices = rng.permutation(len(images))
        return [images[i] for i in indices]
class ShuffleImageTextDatasetNode(io.ComfyNode):
    """Special node that shuffles both images and texts together."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ShuffleImageTextDataset",
            search_aliases=["shuffle", "randomize", "mix"],
            display_name = "Shuffle Pairs of Image-Text",
            category = "image/batch",
            description = "Randomly shuffle the order of pairs of image-text in a list.",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Image.Input("images", tooltip="List of images to shuffle."),
                io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="Random seed.",
                ),
            ],
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=True,
                    tooltip="Shuffled images",
                ),
                io.String.Output(
                    display_name="texts", is_output_list=True, tooltip="Shuffled texts"
                ),
            ],
        )

    @classmethod
    def execute(cls, images, texts, seed):
        """Apply one seeded permutation to both lists so image/text pairs stay aligned.

        Raises:
            ValueError: If ``images`` and ``texts`` differ in length.
        """
        seed = seed[0]  # Extract scalar (is_input_list wraps every input in a list)
        if len(images) != len(texts):
            raise ValueError(
                f"ShuffleImageTextDataset needs one text per image, got {len(images)} images and {len(texts)} texts."
            )
        # Private RandomState: same permutation as the global-seed version,
        # without mutating numpy's global RNG.
        rng = np.random.RandomState(seed % (2**32 - 1))
        indices = rng.permutation(len(images))
        shuffled_images = [images[i] for i in indices]
        shuffled_texts = [texts[i] for i in indices]
        return io.NodeOutput(shuffled_images, shuffled_texts)
# ========== Text Transform Nodes ==========
class TextToLowercaseNode(TextProcessingNode):
    """Lower-case each incoming text (deprecated in favour of the Convert Text Case node)."""

    node_id = "TextToLowercase"
    search_aliases = ["lowercase"]
    display_name = "Convert Text to Lowercase (DEPRECATED)"
    category = "text"
    description = "Convert text to lowercase."
    is_deprecated = True  # This node is superseded by the Convert Text Case node

    @classmethod
    def _process(cls, text):
        lowered = text.lower()
        return lowered
class TextToUppercaseNode(TextProcessingNode):
    """Upper-case each incoming text (deprecated in favour of the Convert Text Case node)."""

    node_id = "TextToUppercase"
    search_aliases = ["uppercase"]
    display_name = "Convert Text to Uppercase (DEPRECATED)"
    category = "text"
    description = "Convert text to uppercase."
    is_deprecated = True  # This node is superseded by the Convert Text Case node

    @classmethod
    def _process(cls, text):
        upper = text.upper()
        return upper
class TruncateTextNode(TextProcessingNode):
    """Keep at most ``max_length`` characters of each text."""

    node_id = "TruncateText"
    search_aliases = ["truncate", "cut", "shorten"]
    display_name = "Truncate Text"
    category = "text"
    description = "Truncate text to a maximum length."
    extra_inputs = [
        io.Int.Input(
            "max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
        ),
    ]

    @classmethod
    def _process(cls, text, max_length):
        # Slicing never raises; shorter texts pass through unchanged.
        return text[0:max_length]
class AddTextPrefixNode(TextProcessingNode):
    """Prepend ``prefix`` to each text (deprecated in favour of the Concatenate Text node)."""

    node_id = "AddTextPrefix"
    display_name = "Add Text Prefix (DEPRECATED)"
    category = "text"
    description = "Add a prefix to all texts."
    is_deprecated = True  # This node is superseded by the Concatenate Text node
    extra_inputs = [
        io.String.Input("prefix", default="", tooltip="Prefix to add."),
    ]

    @classmethod
    def _process(cls, text, prefix):
        return "".join((prefix, text))
class AddTextSuffixNode(TextProcessingNode):
    """Append ``suffix`` to each text (deprecated in favour of the Concatenate Text node)."""

    node_id = "AddTextSuffix"
    display_name = "Add Text Suffix (DEPRECATED)"
    category = "text"
    description = "Add a suffix to all texts."
    is_deprecated = True  # This node is superseded by the Concatenate Text node
    extra_inputs = [
        io.String.Input("suffix", default="", tooltip="Suffix to add."),
    ]

    @classmethod
    def _process(cls, text, suffix):
        return "".join((text, suffix))
class ReplaceTextNode(TextProcessingNode):
    """Replace every occurrence of ``find`` with ``replace`` in each text."""

    node_id = "ReplaceText"
    display_name = "Replace Text (DEPRECATED)"
    category = "text"
    description = "Replace text in all texts."
    is_deprecated = True  # This node is superseded by the other Replace Text node
    extra_inputs = [
        io.String.Input("find", default="", tooltip="Text to find."),
        io.String.Input("replace", default="", tooltip="Text to replace with."),
    ]

    @classmethod
    def _process(cls, text, find, replace):
        # str.replace("", x) inserts x between every character and at both
        # ends. With the default empty "find" that would corrupt every text,
        # so an empty search string means "nothing to replace".
        if not find:
            return text
        return text.replace(find, replace)
class StripWhitespaceNode(TextProcessingNode):
    """Trim leading and trailing whitespace from each text (deprecated in favour of the Trim Text node)."""

    node_id = "StripWhitespace"
    display_name = "Strip Whitespace (DEPRECATED)"
    category = "text"
    description = "Strip leading and trailing whitespace from all texts."
    is_deprecated = True  # This node is superseded by the Trim Text node

    @classmethod
    def _process(cls, text):
        stripped = text.strip()
        return stripped
# ========== Group Processing Example Nodes ==========
class ImageDeduplicationNode(ImageProcessingNode):
    """Drop images whose 64-bit average hash is too close to an earlier kept image."""

    node_id = "ImageDeduplication"
    search_aliases = ["deduplicate", "remove duplicates", "similarity filter"]
    display_name = "Deduplicate Images"
    category = "image/batch"
    description = "Remove duplicate or very similar images from a list."
    is_group_process = True  # Requires full list to compare images
    extra_inputs = [
        io.Float.Input(
            "similarity_threshold",
            default=0.95,
            min=0.0,
            max=1.0,
            tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
            advanced=True,
        ),
    ]

    @classmethod
    def _group_process(cls, images, similarity_threshold):
        """Keep the first image of each near-duplicate group, preserving input order.

        Each image is reduced to an 8x8 grayscale thumbnail; bit k of its hash
        is set when pixel k is brighter than the thumbnail mean. Similarity is
        1 - (differing bits / 64), and an image is dropped when its similarity
        to any already kept image reaches ``similarity_threshold``.
        """
        if len(images) == 0:
            return []

        def average_hash(img_tensor):
            thumb = tensor_to_pil(img_tensor).resize((8, 8), Image.Resampling.LANCZOS).convert("L")
            values = list(thumb.getdata())
            mean_value = sum(values) / len(values)
            bits = 0
            for position, value in enumerate(values):
                if value > mean_value:
                    bits |= 1 << position
            return bits

        fingerprints = [average_hash(img) for img in images]

        kept = []
        for i, fingerprint in enumerate(fingerprints):
            for j in kept:
                differing_bits = bin(fingerprint ^ fingerprints[j]).count("1")
                similarity = 1.0 - (differing_bits / 64.0)
                if similarity >= similarity_threshold:
                    logging.info(
                        f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
                    )
                    break
            else:
                kept.append(i)

        unique_images = [images[i] for i in kept]
        logging.info(
            f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
        )
        return unique_images
class ImageGridNode(ImageProcessingNode):
    """Paste resized images row by row onto one black RGB canvas."""

    node_id = "ImageGrid"
    search_aliases = ["grid", "collage", "combine"]
    display_name = "Make Image Grid"
    category = "image/batch"
    description = "Arrange multiple images into a grid layout."
    is_group_process = True  # Requires full list to create grid
    is_output_list = False  # Outputs single grid image
    extra_inputs = [
        io.Int.Input(
            "columns",
            default=4,
            min=1,
            max=20,
            tooltip="Number of columns in the grid.",
        ),
        io.Int.Input(
            "cell_width",
            default=256,
            min=32,
            max=2048,
            tooltip="Width of each cell in the grid.",
            advanced=True,
        ),
        io.Int.Input(
            "cell_height",
            default=256,
            min=32,
            max=2048,
            tooltip="Height of each cell in the grid.",
            advanced=True,
        ),
        io.Int.Input(
            "padding", default=4, min=0, max=50, tooltip="Padding between images.", advanced=True
        ),
    ]

    @classmethod
    def _group_process(cls, images, columns, cell_width, cell_height, padding):
        """Build a ``columns``-wide grid; each cell is resized (LANCZOS) to cell_width x cell_height.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("Cannot create grid from empty image list")

        num_images = len(images)
        rows = -(-num_images // columns)  # ceiling division
        grid_width = columns * cell_width + (columns - 1) * padding
        grid_height = rows * cell_height + (rows - 1) * padding
        canvas = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))

        step_x = cell_width + padding
        step_y = cell_height + padding
        for idx, img_tensor in enumerate(images):
            row, col = divmod(idx, columns)
            tile = tensor_to_pil(img_tensor).resize((cell_width, cell_height), Image.Resampling.LANCZOS)
            canvas.paste(tile, (col * step_x, row * step_y))

        logging.info(
            f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
        )
        return pil_to_tensor(canvas)
class MergeImageListsNode(ImageProcessingNode):
    """Deprecated: combine several image lists into one list."""

    node_id = "MergeImageLists"
    search_aliases = ["list", "merge list", "make list"]
    display_name = "Merge Image Lists (DEPRECATED)"
    category = "image/batch"
    description = "Concatenate multiple image lists into one."
    is_group_process = True  # the images arrive together as a single list
    is_deprecated = True  # superseded by the Create List node

    @classmethod
    def _group_process(cls, images):
        """Return ``images`` unchanged.

        Concatenation of multiple connected list inputs happens in input
        handling before this is called, so this is a logging pass-through.
        """
        merged = images
        logging.info(f"Merged image list contains {len(merged)} images")
        return merged
class MergeTextListsNode(TextProcessingNode):
    """Deprecated: combine several text lists into one list."""

    node_id = "MergeTextLists"
    display_name = "Merge Text Lists (DEPRECATED)"
    category = "text"
    description = "Concatenate multiple text lists into one."
    is_group_process = True  # the texts arrive together as a single list
    is_deprecated = True  # superseded by the Create List node

    @classmethod
    def _group_process(cls, texts):
        """Return ``texts`` unchanged.

        Concatenation of multiple connected list inputs happens in input
        handling before this is called, so this is a logging pass-through.
        """
        merged = texts
        logging.info(f"Merged text list contains {len(merged)} texts")
        return merged
# ========== Training Dataset Nodes ==========

# Sentinel key used in the "skeleton" to mark where a tensor lived in the
# original nested structure. The skeleton is pickled; tensors live in the
# accompanying safetensors file under the referenced key.
# On load, any skeleton dict that is exactly {_TREF_KEY: key} is treated as a
# tensor reference (see _rejoin_tensors), so a genuine one-key dict with this
# key inside user data would be misinterpreted.
_TREF_KEY = "__tref__"
def _split_tensors(obj, out_tensors, prefix):
    """Walk ``obj`` recursively and pull every tensor out into ``out_tensors``.

    Each tensor is stored under the key ``f"{prefix}_{N}"``, where N is the
    current size of ``out_tensors``. The return value is a "skeleton" with the
    same structure as ``obj`` but each tensor replaced by ``{"__tref__": key}``.
    Anything that isn't a tensor / dict / list / tuple (Hook objects, floats,
    strings, custom extension types, ...) passes through untouched and is left
    for pickle. Tuple subclasses (e.g. namedtuples) come back as plain tuples.

    Args:
        obj: nested structure to split.
        out_tensors: dict mutated in place; receives the extracted tensors.
        prefix: key prefix for tensors extracted by this call.

    Returns:
        The skeleton structure.
    """
    if isinstance(obj, torch.Tensor):
        key = f"{prefix}_{len(out_tensors)}"
        # A plain clone() keeps the strides of dense non-contiguous tensors
        # (e.g. transposed/permuted views), and safetensors refuses to save
        # non-contiguous tensors. Forcing contiguous_format fixes that, and the
        # copy also gives each entry its own storage (safetensors rejects
        # tensors that share memory).
        out_tensors[key] = obj.detach().cpu().clone(
            memory_format=torch.contiguous_format
        )
        return {_TREF_KEY: key}
    elif isinstance(obj, dict):
        return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_split_tensors(v, out_tensors, prefix) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
    return obj
def _rejoin_tensors(obj, tensor_getter):
    """Rebuild a structure produced by ``_split_tensors``.

    Every ``{"__tref__": key}`` marker (a dict with exactly that one key) is
    replaced by ``tensor_getter(key)``. Other dicts, lists and tuples are
    rebuilt recursively, and all other values are returned as-is.
    """
    if isinstance(obj, list):
        return [_rejoin_tensors(item, tensor_getter) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_rejoin_tensors(item, tensor_getter) for item in obj)
    if not isinstance(obj, dict):
        return obj
    is_marker = len(obj) == 1 and _TREF_KEY in obj
    if is_marker:
        return tensor_getter(obj[_TREF_KEY])
    return {key: _rejoin_tensors(val, tensor_getter) for key, val in obj.items()}
# safetensors dtype strings -> torch dtype, used to read shapes/dtypes from the
# header without loading any tensor bytes.
_ST_STR_TO_DTYPE = {
    "F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
    "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32,
    "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
}
# safetensors can also store these dtypes, but only newer torch builds expose
# them. Register each one that exists so a header using it (e.g. an fp8 tensor)
# does not raise KeyError on lookup. A comprehension keeps the loop names out of
# the module namespace.
_ST_STR_TO_DTYPE.update({
    st_name: getattr(torch, torch_attr)
    for st_name, torch_attr in (
        ("F8_E4M3", "float8_e4m3fn"),
        ("F8_E5M2", "float8_e5m2"),
        ("U16", "uint16"),
        ("U32", "uint32"),
        ("U64", "uint64"),
    )
    if hasattr(torch, torch_attr)
})
def _read_safetensors_header(path):
    """Return the per-tensor entries of a safetensors file's JSON header.

    Only the header is read: an 8-byte little-endian length prefix followed by
    that many bytes of JSON. The optional "__metadata__" entry is dropped, so
    the result maps tensor keys to their header entries (dtype, shape, ...).
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        raw_header = fh.read(header_len)
    parsed = json.loads(raw_header)
    parsed.pop("__metadata__", None)
    return parsed
class RealizeRequired(RuntimeError):
    """Signals that lazy, disk-backed dataset data was used where real tensors
    are required.

    Fix it by realizing the data first: call ``.realize()`` in code, or insert
    the Realize Lazy Latents / Realize Lazy Conditionings nodes in a workflow.
    """
def _need_realize(self, *args, **kwargs):
    """Stand-in for every operation a lazy dataset object refuses to perform.

    Accepts and ignores any arguments, then always raises RealizeRequired,
    naming the offending object's type.
    """
    kind = type(self).__name__
    raise RealizeRequired(
        f"{kind} is lazy on-disk data and does not support this "
        "operation. Realize it first (.realize() or a Realize node)."
    )
class LazyTensorInfo:
    """Header-only description of one tensor stored in a safetensors file.

    Holds the shape and dtype taken from the safetensors header, so building
    one reads no tensor bytes. Only ``shape``, ``dtype`` and ``ndim`` are
    available; any other attribute lookup raises RealizeRequired.
    """

    def __init__(self, shape, dtype):
        size = torch.Size(shape)
        self.shape = size
        self.dtype = dtype
        self.ndim = len(size)

    def __repr__(self):
        dims = tuple(self.shape)
        return f"LazyTensorInfo(shape={dims}, dtype={self.dtype})"

    # Lookups that fall through normal attribute resolution end up here and
    # raise RealizeRequired.
    __getattr__ = _need_realize
class LazyLatent:
    """Disk-backed handle to one dataset sample's latent dict ({"samples": ...}).

    Holds only a shard reader and the sample's skeleton. Indexing a
    tensor-valued key (e.g. ``latent["samples"]``) returns a LazyTensorInfo
    built from the safetensors header, so shape/dtype are known without
    reading tensor bytes. Non-tensor values are returned as stored.

    Tensor values must be realized: ``realize()`` returns the real latent dict
    and ``realize_samples()`` returns only the real "samples" tensor. Nothing is
    cached, so keeping many handles alive costs almost no RAM (repeated reads
    rely on the OS page cache).
    """

    def __init__(self, reader, skeleton):
        self._reader = reader
        self._skel = skeleton

    def __getitem__(self, name):
        value = self._skel[name]
        is_tensor_ref = isinstance(value, dict) and len(value) == 1 and _TREF_KEY in value
        if not is_tensor_ref:
            return value  # a plain non-tensor entry (e.g. batch_index)
        ref = value[_TREF_KEY]
        reader = self._reader
        return LazyTensorInfo(reader.shape(ref), reader.dtype(ref))

    def realize(self):
        """Load this sample's tensors and return the real latent dict."""
        fetch = self._reader.get_tensor
        return _rejoin_tensors(self._skel, fetch)

    def realize_samples(self):
        """Load and return only the real "samples" tensor."""
        ref = self._skel["samples"][_TREF_KEY]
        return self._reader.get_tensor(ref)

    def __repr__(self):
        meta = self["samples"]
        dims = tuple(meta.shape)
        return f"LazyLatent(samples={dims}, dtype={meta.dtype})"
class LazyConditioning:
    """Disk-backed handle to one dataset sample's conditioning.

    The conditioning is an arbitrary pickled structure, so no metadata is
    exposed. The only access is ``realize()``, which returns the full list of
    [tensor, dict] entries.
    """

    def __init__(self, reader, skeleton):
        self._reader = reader
        self._skel = skeleton

    def realize(self):
        """Load this sample's conditioning tensors and rebuild the structure."""
        fetch = self._reader.get_tensor
        return _rejoin_tensors(self._skel, fetch)

    # A realized conditioning already is its entry list, so both names share
    # one implementation.
    realize_entries = realize

    def __repr__(self):
        return "LazyConditioning(on-disk)"
class LazyCondEntry:
    """A single entry of a LazyConditioning, addressed by position.

    ResolutionBucket emits these so each bucket row is paired with exactly one
    conditioning entry.
    """

    def __init__(self, lazy_cond, index):
        self._cond = lazy_cond
        self._index = index

    def realize(self):
        """Realize the parent conditioning and return only entry ``index``."""
        entries = self._cond.realize()
        return entries[self._index]

    def realize_entries(self):
        """Return the realized entry wrapped in a one-element list."""
        single = self.realize()
        return [single]

    def __repr__(self):
        idx = self._index
        return f"LazyCondEntry(index={idx})"
class LazyBatchSamples:
    """The "samples" batch of one resolution bucket.

    Holds N equal-shape rows. Each row is either an on-disk LazyLatent (stored
    as (1, *row_shape)) or an already-real row tensor, which happens when eager
    and lazy inputs are mixed. ``.shape``/``.dtype`` come from the first row's
    metadata. ``realize_rows(indices)`` reads only the selected rows, the unit
    of I/O per training step.
    """

    def __init__(self, rows):
        self.rows = list(rows)
        head = self.rows[0]
        if isinstance(head, LazyLatent):
            meta = head["samples"]
            row_shape = tuple(meta.shape[1:])
            self.dtype = meta.dtype
        else:
            row_shape = tuple(head.shape)
            self.dtype = head.dtype
        self.shape = torch.Size((len(self.rows),) + row_shape)
        self.ndim = len(self.shape)

    def _row(self, i):
        """Return row ``i`` as a real (*row_shape) tensor."""
        row = self.rows[int(i)]
        if isinstance(row, LazyLatent):
            return row.realize_samples()[0]
        return row

    def realize_rows(self, indices):
        """Read only the selected rows; return them stacked as (len(indices), *row_shape)."""
        picked = [self._row(i) for i in indices]
        return torch.stack(picked, dim=0)

    def realize(self):
        """Read and stack every row: (N, *row_shape)."""
        every = range(len(self.rows))
        return self.realize_rows(every)

    def __repr__(self):
        dims = tuple(self.shape)
        return f"LazyBatchSamples(shape={dims}, dtype={self.dtype})"
# Every lazy, disk-backed dataset handle type; _realize_structure calls
# .realize() on any instance of these.
_LAZY_DATASET_TYPES = (LazyLatent, LazyConditioning, LazyCondEntry, LazyBatchSamples)
# Any op a lazy class doesn't define itself (indexing, iteration, math,
# truthiness, pickling) raises RealizeRequired instead of silently misbehaving.
# Only dunders absent from a class's own __dict__ are patched, so methods a
# class defines itself (e.g. LazyLatent.__getitem__) are kept. Python looks up
# implicit special methods on the type, which is why they are set on the class
# here; LazyTensorInfo.__getattr__ alone would not catch them.
# Note: the loop variables _cls and _op remain bound at module level.
for _cls in (LazyTensorInfo, *_LAZY_DATASET_TYPES):
    for _op in ("__getitem__", "__iter__", "__len__", "__bool__", "__reduce__",
                "__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__",
                "__truediv__", "__matmul__", "__neg__"):
        if _op not in _cls.__dict__:
            setattr(_cls, _op, _need_realize)
def _realize_structure(obj):
    """Return ``obj`` with every lazy dataset handle replaced by its realized value.

    Lazy handles (instances of _LAZY_DATASET_TYPES) are realized via
    ``.realize()``. Dicts, lists and tuples are rebuilt recursively, and
    anything else, including real tensors, is returned unchanged.
    """
    if isinstance(obj, _LAZY_DATASET_TYPES):
        return obj.realize()
    if isinstance(obj, dict):
        return {key: _realize_structure(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        realized = [_realize_structure(item) for item in obj]
        return realized if isinstance(obj, list) else tuple(realized)
    return obj
class _ShardReader:
    """Random-access reader over one dataset shard (safetensors + skeleton pickle).

    The small skeleton pickle is loaded in ``__init__``. The safetensors file
    is opened lazily on first tensor access and read through safe_open's
    per-tensor access, so ``read_sample(i)`` reads only sample i's tensors and
    ``read_sample_lazy(i)`` reads none at all. It returns (LazyLatent,
    LazyConditioning) handles that read on demand.

    SECURITY NOTE: the skeleton is loaded with ``pickle.load``, which can
    execute arbitrary code. Only load datasets from trusted sources.
    """

    def __init__(self, shard_path, skeleton_path):
        with open(skeleton_path, "rb") as fh:
            self.skeletons = pickle.load(fh)
        self.shard_path = shard_path
        self._st = None
        self._header = None

    def _open(self):
        """Return the cached safe_open handle, creating it on first use."""
        if self._st is None:
            self._st = safe_open(self.shard_path, framework="pt")
        return self._st

    @property
    def header(self):
        """The shard's safetensors header (tensor key -> entry), parsed once."""
        if self._header is None:
            self._header = _read_safetensors_header(self.shard_path)
        return self._header

    def shape(self, key):
        """Stored shape of tensor ``key``, from the header only."""
        entry = self.header[key]
        return tuple(entry["shape"])

    def dtype(self, key):
        """Torch dtype of tensor ``key``, from the header only."""
        entry = self.header[key]
        return _ST_STR_TO_DTYPE[entry["dtype"]]

    def get_tensor(self, key):
        """Read tensor ``key`` from the shard."""
        handle = self._open()
        return handle.get_tensor(key)

    def get_slice(self, key):
        """Return safe_open's slice object for tensor ``key``."""
        handle = self._open()
        return handle.get_slice(key)

    def __len__(self):
        return len(self.skeletons)

    def read_sample(self, local_idx):
        """Return (latent_dict, conditioning_list) for one sample, reading its
        tensors eagerly."""
        latent_skel, cond_skel = self.skeletons[local_idx]
        fetch = self._open().get_tensor
        return _rejoin_tensors(latent_skel, fetch), _rejoin_tensors(cond_skel, fetch)

    def read_sample_lazy(self, local_idx):
        """Return (LazyLatent, LazyConditioning) handles for one sample.

        No tensor bytes are read. The handles carry this sample's skeleton, so
        ``latent["samples"].shape/.dtype`` come from the safetensors header and
        ``realize()`` reads only this sample's tensors.
        """
        latent_skel, cond_skel = self.skeletons[local_idx]
        return LazyLatent(self, latent_skel), LazyConditioning(self, cond_skel)
class ResolutionBucket(io.ComfyNode):
    """Bucket latents and conditions by resolution for efficient batch training."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ResolutionBucket",
            search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
            display_name="Resolution Bucket",
            category="model/training",
            description="Group latents and conditionings into buckets",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Latent.Input(
                    "latents",
                    tooltip="List of latent dicts to bucket by resolution.",
                ),
                io.Conditioning.Input(
                    "conditioning",
                    tooltip="List of conditioning lists (must match latents length).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of batched latent dicts, one per resolution bucket.",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of condition lists, one per resolution bucket.",
                ),
            ],
        )

    @classmethod
    def execute(cls, latents, conditioning):
        """Group samples into stackable buckets of equal per-row latent shape.

        Args:
            latents: list of latent dicts {"samples": (B, C, H, W)} and/or
                LazyLatent handles (whose stored batch size must be 1).
            conditioning: list aligned with ``latents``. Each item is a list of
                [tensor, dict] entries or a LazyConditioning handle.

        Returns:
            NodeOutput(list of {"samples": batch} dicts, list of per-bucket
            conditioning-entry lists), one pair per bucket. A bucket with any
            lazy row gets a LazyBatchSamples; otherwise a stacked real tensor.

        Raises:
            ValueError: the two lists differ in length, or a real conditioning
                has fewer entries than its latent batch and is not a single
                shared entry.
            RealizeRequired: a lazy latent has stored batch size > 1.
        """
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
            )

        # Buckets are keyed by the full per-row shape (everything after the batch
        # dim), not just (H, W). Rows that share H/W but differ in channel or
        # frame count could not be torch.stack-ed into one batch. Lazy rows are
        # keyed from header metadata only, so no tensor bytes are read here.
        buckets = {}  # row_shape -> {"hw": (h, w), "rows": [...], "conds": [...]}
        any_lazy = False

        def _bucket_for(shape):
            """Return (creating on first use) the bucket for a (B, *row_shape) shape."""
            row_shape = tuple(int(d) for d in shape[1:])
            return buckets.setdefault(
                row_shape,
                {"hw": (int(shape[-2]), int(shape[-1])), "rows": [], "conds": []},
            )

        for latent, cond in zip(latents, conditioning):
            if isinstance(latent, LazyLatent):
                info = latent["samples"]
                if int(info.shape[0]) != 1:
                    raise RealizeRequired(
                        "ResolutionBucket: lazy latents with stored batch size > 1 "
                        "are not supported; insert a Realize Lazy Latents node first."
                    )
                any_lazy = True
                bucket = _bucket_for(info.shape)
                bucket["rows"].append(latent)
                bucket["conds"].append(
                    LazyCondEntry(cond, 0) if isinstance(cond, LazyConditioning) else cond[0]
                )
                continue

            samples = latent["samples"]  # (B, C, H, W) real tensor
            batch_size = int(samples.shape[0])
            bucket = _bucket_for(samples.shape)
            if isinstance(cond, LazyConditioning):
                for i in range(batch_size):
                    bucket["rows"].append(samples[i])
                    bucket["conds"].append(LazyCondEntry(cond, i))
                continue

            # Normally one entry per batch row. A single entry (e.g. one caption
            # encoded for a batched image) is shared by every row instead of
            # failing with IndexError.
            if len(cond) != 1 and len(cond) < batch_size:
                raise ValueError(
                    f"Conditioning has {len(cond)} entries but its latent batch has "
                    f"{batch_size} rows; expected 1 or at least {batch_size}."
                )
            for i in range(batch_size):
                bucket["rows"].append(samples[i])
                bucket["conds"].append(cond[i] if len(cond) > 1 else cond[0])

        output_latents = []  # list[{"samples": (Bi, *row_shape)}]
        output_conditions = []  # list[list[cond entry]] with Bi entries each
        total = 0
        for bucket_data in buckets.values():
            rows = bucket_data["rows"]
            h, w = bucket_data["hw"]
            total += len(rows)
            if any(isinstance(r, LazyLatent) for r in rows):
                samples = LazyBatchSamples(rows)
            else:
                samples = torch.stack(rows, dim=0)
            output_latents.append({"samples": samples})
            output_conditions.append(bucket_data["conds"])
            logging.info(f"Resolution bucket ({h}x{w}): {len(rows)} samples")

        logging.info(
            f"Created {len(buckets)} resolution buckets from {total} samples "
            f"({'lazy' if any_lazy else 'eager'})"
        )
        return io.NodeOutput(output_latents, output_conditions)
class MakeTrainingDataset(io.ComfyNode):
    """Encode images with VAE and texts with CLIP to create a training dataset."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="MakeTrainingDataset",
            search_aliases=["encode dataset"],
            display_name="Make Training Dataset",
            category="model/training",
            description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
            is_experimental=True,
            is_input_list=True,  # images and texts as lists
            inputs=[
                io.Image.Input("images", tooltip="List of images to encode."),
                io.Vae.Input(
                    "vae", tooltip="VAE model for encoding images to latents."
                ),
                io.Clip.Input(
                    "clip", tooltip="CLIP model for encoding text to conditioning."
                ),
                io.String.Input(
                    "texts",
                    optional=True,
                    tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
                    force_input=True
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of latent dicts",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of conditioning lists",
                ),
            ],
        )

    @classmethod
    def execute(cls, images, vae, clip, texts=None):
        """Encode every image to a latent and every caption to a conditioning.

        Args:
            images: list of image tensors (each presumably [B, H, W, C]; only the
                first 3 channels are encoded).
            vae, clip: single models, wrapped in one-element lists because the
                node uses is_input_list=True.
            texts: captions. None/empty means one "" caption (unconditional).
                A single caption is repeated for every image; otherwise the
                length must match ``images``.

        Returns:
            NodeOutput(list of {"samples": latent} dicts, list of conditionings),
            aligned with ``images``.

        Raises:
            ValueError: no images, or the caption count is neither 1 nor len(images).
        """
        # Extract scalars (vae and clip are single values wrapped in lists)
        vae = vae[0]
        clip = clip[0]

        num_images = len(images)
        if num_images == 0:
            # Otherwise the caption-count check below fails with a misleading message.
            raise ValueError("No images provided; cannot create a training dataset.")

        if not texts:
            # Treat as [""] for unconditional training
            texts = [""]
        if len(texts) == 1 and num_images > 1:
            # Repeat single text for all images
            texts = texts * num_images
        elif len(texts) != num_images:
            raise ValueError(
                f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
                f"Text list should have length {num_images}, 1, or 0."
            )

        logging.info(f"Encoding {num_images} images with VAE...")
        latents_list = []  # list[{"samples": tensor}]
        for img_tensor in images:
            # Drop any alpha channel before encoding.
            latent_tensor = vae.encode(img_tensor[:, :, :, :3])
            latents_list.append({"samples": latent_tensor})

        logging.info(f"Encoding {len(texts)} texts with CLIP...")
        # The empty caption goes through the same tokenize/encode path as any other
        # text; a separate "" branch did the identical calls.
        conditioning_list = [
            clip.encode_from_tokens_scheduled(clip.tokenize(text)) for text in texts
        ]  # list[list[cond]]

        logging.info(
            f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
        )
        return io.NodeOutput(latents_list, conditioning_list)
class SaveTrainingDataset(io.ComfyNode):
    """Write an encoded training dataset (latents + conditioning) to disk as shards."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SaveTrainingDataset",
            search_aliases=["export dataset", "save dataset"],
            display_name="Save Training Dataset",
            category="model/training",
            description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,  # Receive lists
            inputs=[
                io.Latent.Input(
                    "latents",
                    tooltip="List of latent dicts from MakeTrainingDataset.",
                ),
                io.Conditioning.Input(
                    "conditioning",
                    tooltip="List of conditioning lists from MakeTrainingDataset.",
                ),
                io.String.Input(
                    "folder_name",
                    default="training_dataset",
                    tooltip="Name of folder to save dataset (inside output directory).",
                ),
                io.Int.Input(
                    "shard_size",
                    default=1000,
                    min=1,
                    max=100000,
                    tooltip="Number of samples per shard file.",
                    advanced=True,
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, latents, conditioning, folder_name, shard_size):
        """Save the dataset as shard_NNNN.{safetensors,skeleton.pkl} plus metadata.json.

        Each shard holds up to ``shard_size`` samples. All tensors go into one
        safetensors file, and the nested structures (with tensor-reference
        markers) go into a pickled skeleton list with one (latent_skeleton,
        cond_skeleton) pair per sample. Lazy inputs are realized one sample at a
        time, so at most one shard's tensors are held in RAM.

        Raises:
            ValueError: if ``latents`` and ``conditioning`` differ in length.
        """
        # is_input_list=True delivers every input as a list; unwrap the widgets.
        folder_name, shard_size = folder_name[0], shard_size[0]

        # latents: list[{"samples": tensor}] (or LazyLatent handles)
        # conditioning: list[list[[cond_tensor, dict]]] from
        # encode_from_tokens_scheduled (or LazyConditioning handles). The dicts
        # may hold arbitrary extension objects (Hook objects, floats, strings,
        # ...), which end up in the pickled skeleton.
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
                f"Something went wrong in dataset preparation."
            )

        # [TODO] can save to anywhere <- need to be resolve
        # NOTE(review): folder_name is joined unchecked, so ".." or an absolute
        # path can place the dataset outside the output directory.
        output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
        os.makedirs(output_dir, exist_ok=True)

        total = len(latents)
        shard_count = -(-total // shard_size)  # ceiling division
        logging.info(
            f"Saving {total} samples to {shard_count} shards in {output_dir}..."
        )

        for shard_no in range(shard_count):
            first = shard_no * shard_size
            stop = min(first + shard_size, total)

            tensors = {}
            skeletons = []  # (latent_skeleton, cond_skeleton) per sample
            for local_idx in range(stop - first):
                sample_idx = first + local_idx
                lat_skel = _split_tensors(
                    _realize_structure(latents[sample_idx]), tensors, f"s{local_idx}_lat"
                )
                cond_skel = _split_tensors(
                    _realize_structure(conditioning[sample_idx]), tensors, f"s{local_idx}_cond"
                )
                skeletons.append((lat_skel, cond_skel))

            stem = os.path.join(output_dir, f"shard_{shard_no:04d}")
            safetensors.torch.save_file(tensors, stem + ".safetensors")
            with open(stem + ".skeleton.pkl", "wb") as fh:
                pickle.dump(skeletons, fh, protocol=pickle.HIGHEST_PROTOCOL)

            logging.info(
                f"Saved shard {shard_no + 1}/{shard_count}: {stop - first} samples, "
                f"{len(tensors)} tensors"
            )

        metadata = {
            "num_samples": total,
            "num_shards": shard_count,
            "shard_size": shard_size,
            "format_version": 2,
        }
        with open(os.path.join(output_dir, "metadata.json"), "w") as fh:
            json.dump(metadata, fh, indent=2)

        logging.info(f"Successfully saved {total} samples to {output_dir}.")
        return io.NodeOutput()
def _shard_sort_key(name):
    """Sort key ordering "shard_<N>.safetensors" names by numeric N, then by name."""
    stem = name[len("shard_"):-len(".safetensors")]
    # isdecimal (not isdigit) guarantees int() accepts the stem.
    return (0, int(stem), name) if stem.isdecimal() else (1, 0, name)


def _list_shard_files(dataset_dir):
    """Return the dataset's shard filenames, in shard-index order.

    When metadata.json (written by SaveTrainingDataset) gives ``num_shards`` and
    all of shard_0000 .. shard_{num_shards-1} exist, exactly those are
    returned. Stale shards left by an earlier, larger save into the same folder
    are then ignored.

    Otherwise (no or unreadable metadata, or missing listed shards), every
    shard_*.safetensors file is returned, sorted by numeric index. A plain
    string sort would misorder indices past 9999.
    """
    listed = sorted(
        (f for f in os.listdir(dataset_dir)
         if f.startswith("shard_") and f.endswith(".safetensors")),
        key=_shard_sort_key,
    )
    metadata_path = os.path.join(dataset_dir, "metadata.json")
    if not os.path.isfile(metadata_path):
        return listed
    try:
        with open(metadata_path, "r", encoding="utf-8") as f:
            num_shards = int(json.load(f)["num_shards"])
    except (OSError, ValueError, KeyError, TypeError) as e:
        logging.warning(f"Ignoring unreadable dataset metadata {metadata_path}: {e}")
        return listed

    # Same naming scheme SaveTrainingDataset writes.
    expected = [f"shard_{i:04d}.safetensors" for i in range(num_shards)]
    present = set(listed)
    missing = [f for f in expected if f not in present]
    if missing:
        logging.warning(
            f"{metadata_path} lists {num_shards} shards but {len(missing)} are missing "
            f"(e.g. {missing[0]}); using all shard files found instead."
        )
        return listed
    extra = len(present - set(expected))
    if extra:
        logging.warning(
            f"Ignoring {extra} shard file(s) in {dataset_dir} not listed in "
            f"metadata.json (left over from an earlier save?)."
        )
    return expected


class LoadTrainingDataset(io.ComfyNode):
    """Load encoded training dataset from disk as lazy references.

    Outputs list[LazyLatent] and list[LazyConditioning], one handle per sample,
    using near-zero RAM. Latent shapes/dtypes are readable from metadata (e.g.
    by Resolution Bucket) without any I/O. Tensor bytes are read per batch
    inside the lazy-aware trainer. For any other consumer, insert the Realize
    Lazy Latents / Realize Lazy Conditionings nodes to get standard in-RAM data.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadTrainingDataset",
            search_aliases=["import dataset", "training data", "lazy", "streaming"],
            display_name="Load Training Dataset",
            category="model/training",
            description="Load an encoded training dataset from disk as lazy references; tensors are read on demand during training instead of all at once.",
            is_experimental=True,
            inputs=[
                io.String.Input(
                    "folder_name",
                    default="training_dataset",
                    tooltip="Name of folder containing the saved dataset (inside output directory).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of latent dicts",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of conditioning lists",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder_name):
        """Index every sample of the dataset as lazy handles (no tensor bytes read).

        Raises:
            ValueError: the dataset folder does not exist or contains no shards.
            FileNotFoundError: a shard's .skeleton.pkl companion is missing.
        """
        dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
        if not os.path.isdir(dataset_dir):
            raise ValueError(f"Dataset directory not found: {dataset_dir}")

        shard_files = _list_shard_files(dataset_dir)
        if not shard_files:
            raise ValueError(
                f"No shard files found in {dataset_dir} "
                f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
            )

        logging.info(f"Lazy-loading {len(shard_files)} shards from {dataset_dir}...")

        all_latents = []  # list[LazyLatent]
        all_conditioning = []  # list[LazyConditioning]
        for shard_file in shard_files:
            shard_path = os.path.join(dataset_dir, shard_file)
            skeleton_path = os.path.join(
                dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
            )
            if not os.path.isfile(skeleton_path):
                raise FileNotFoundError(
                    f"Missing skeleton file for {shard_file}: {skeleton_path}"
                )
            # Reads only the skeleton pickle (+ header later), no tensor bytes.
            # SECURITY: the skeleton is unpickled, which can run arbitrary code;
            # only load datasets from trusted sources.
            reader = _ShardReader(shard_path, skeleton_path)
            for local_idx in range(len(reader)):
                latent, cond = reader.read_sample_lazy(local_idx)
                all_latents.append(latent)
                all_conditioning.append(cond)
            logging.info(f"Indexed {shard_file}: {len(reader)} samples")

        logging.info(
            f"Lazy-loaded {len(all_latents)} samples from {dataset_dir} "
            f"(no tensor data read yet)."
        )
        return io.NodeOutput(all_latents, all_conditioning)
class RealizeLazyLatents(io.ComfyNode):
    """Load every lazy latent from disk into RAM as a standard latent dict.

    Place this before any node that is not lazy-aware (anything that stacks or
    does tensor math on the latents). Real latents pass through unchanged, so
    it is safe to use unconditionally.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="RealizeLazyLatents",
            search_aliases=["realize", "materialize", "load to ram", "realize latents"],
            display_name="Realize Lazy Latents",
            category="model/training",
            description="Read all lazy latent tensors from disk into memory, producing standard in-RAM latent dicts.",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Latent.Input("latents", tooltip="Lazy (or real) latent dicts."),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="Realized (in-RAM) latent dicts",
                ),
            ],
        )

    @classmethod
    def execute(cls, latents):
        """Realize each input latent (lazy handles anywhere in the structure)."""
        realized = list(map(_realize_structure, latents))
        logging.info(f"Realized {len(realized)} latents into RAM.")
        return io.NodeOutput(realized)
class RealizeLazyConditionings(io.ComfyNode):
    """Load every lazy conditioning from disk into RAM as standard conditioning.

    Place this before any node that is not lazy-aware. Real conditioning passes
    through unchanged, so it is safe to use unconditionally.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="RealizeLazyConditionings",
            search_aliases=["realize", "materialize", "load to ram", "realize conditioning"],
            display_name="Realize Lazy Conditionings",
            category="model/training",
            description="Read all lazy conditioning tensors from disk into memory, producing standard in-RAM conditioning.",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Conditioning.Input(
                    "conditioning", tooltip="Lazy (or real) conditioning."
                ),
            ],
            outputs=[
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="Realized (in-RAM) conditioning",
                ),
            ],
        )

    @classmethod
    def execute(cls, conditioning):
        """Realize each input conditioning (lazy handles anywhere in the structure)."""
        realized = list(map(_realize_structure, conditioning))
        logging.info(f"Realized {len(realized)} conditionings into RAM.")
        return io.NodeOutput(realized)
# ========== Extension Setup ==========


class DatasetExtension(ComfyExtension):
    """Extension that registers every node class defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension provides, grouped by purpose."""
        return [
            # Data loading/saving nodes
            LoadImageDataSetFromFolderNode,
            LoadImageTextDataSetFromFolderNode,
            SaveImageDataSetToFolderNode,
            SaveImageTextDataSetToFolderNode,
            # Image transform nodes
            ResizeImagesByShorterEdgeNode,
            ResizeImagesByLongerEdgeNode,
            CenterCropImagesNode,
            RandomCropImagesNode,
            NormalizeImagesNode,
            AdjustBrightnessNode,
            AdjustContrastNode,
            ShuffleDatasetNode,
            ShuffleImageTextDatasetNode,
            # Text transform nodes
            TextToLowercaseNode,
            TextToUppercaseNode,
            TruncateTextNode,
            AddTextPrefixNode,
            AddTextSuffixNode,
            ReplaceTextNode,
            StripWhitespaceNode,
            # Group processing examples
            ImageDeduplicationNode,
            ImageGridNode,
            MergeImageListsNode,  # deprecated; kept registered for existing workflows
            MergeTextListsNode,  # deprecated; kept registered for existing workflows
            # Training dataset nodes
            MakeTrainingDataset,
            SaveTrainingDataset,
            LoadTrainingDataset,
            RealizeLazyLatents,
            RealizeLazyConditionings,
            ResolutionBucket,
        ]
async def comfy_entrypoint() -> DatasetExtension:
    """Module entry point: build and return a fresh DatasetExtension.

    Presumably awaited by ComfyUI's extension loader; confirm against the loader.
    """
    extension = DatasetExtension()
    return extension