mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
2034 lines
74 KiB
Python
2034 lines
74 KiB
Python
import logging
|
|
import os
|
|
import json
|
|
import pickle
|
|
import struct
|
|
|
|
import numpy as np
|
|
import safetensors.torch
|
|
import torch
|
|
from PIL import Image
|
|
from safetensors import safe_open
|
|
from typing_extensions import override
|
|
|
|
import folder_paths
|
|
import node_helpers
|
|
from comfy_api.latest import ComfyExtension, io
|
|
|
|
|
|
def load_and_process_images(image_files, input_dir):
    """Load image files and convert each one to a float RGB tensor.

    Args:
        image_files: List of image filenames (or paths). Each entry is joined
            onto ``input_dir`` with ``os.path.join``.
        input_dir: Base directory containing the images.

    Returns:
        list[torch.Tensor]: One float32 tensor per file, shaped [1, H, W, 3]
        with values in [0, 1]. The tensors are returned as a list and are NOT
        stacked into a single batch.

    Raises:
        ValueError: If ``image_files`` is empty.
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []

    for file in image_files:
        image_path = os.path.join(input_dir, file)

        # Image.open keeps the underlying file open until the image is closed.
        # The context manager releases the handle as soon as convert() has
        # copied the pixels out, instead of waiting for garbage collection.
        with node_helpers.pillow(Image.open, image_path) as opened:
            source = opened
            if source.mode == "I":
                # Integer-mode images are rescaled before the RGB conversion.
                source = source.point(lambda i: i * (1 / 255))
            rgb = source.convert("RGB")

        img_array = np.array(rgb).astype(np.float32) / 255.0
        # Prepend a batch dimension: [H, W, 3] -> [1, H, W, 3].
        output_images.append(torch.from_numpy(img_array)[None,])

    return output_images
|
|
|
|
|
|
class LoadImageDataSetFromFolderNode(io.ComfyNode):
    """Node that loads every supported image found directly inside an input subfolder."""

    @classmethod
    def define_schema(cls):
        """Build the schema: one folder selector in, one list of images out."""
        folder_input = io.Combo.Input(
            "folder",
            options=folder_paths.get_input_subfolders(),
            tooltip="The folder to load images from.",
        )
        images_output = io.Image.Output(
            display_name="images",
            is_output_list=True,
            tooltip="List of loaded images",
        )
        return io.Schema(
            node_id="LoadImageDataSetFromFolder",
            search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
            display_name="Load Image (from Folder)",
            category="image",
            description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.",
            is_experimental=True,
            inputs=[folder_input],
            outputs=[images_output],
        )

    @classmethod
    def execute(cls, folder):
        """Load the PNG/JPG/JPEG/WEBP files listed in ``folder`` (non-recursive).

        Args:
            folder: Subfolder name, resolved relative to the input directory.

        Returns:
            io.NodeOutput wrapping the list produced by load_and_process_images.
        """
        source_dir = os.path.join(folder_paths.get_input_directory(), folder)
        suffixes = (".png", ".jpg", ".jpeg", ".webp")

        # Extension matching is case-insensitive; listing order is kept as-is.
        names = []
        for entry in os.listdir(source_dir):
            if entry.lower().endswith(suffixes):
                names.append(entry)

        return io.NodeOutput(load_and_process_images(names, source_dir))
|
|
|
|
|
|
class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
    """Node that loads images together with their sidecar ``.txt`` captions."""

    @classmethod
    def define_schema(cls):
        """Build the schema: one folder selector in, image list and caption list out."""
        return io.Schema(
            node_id="LoadImageTextDataSetFromFolder",
            search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"],
            display_name="Load Image-Text (from Folder)",
            category="image",
            description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder",
                    options=folder_paths.get_input_subfolders(),
                    tooltip="The folder to load images and text captions from.",
                )
            ],
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=True,
                    tooltip="List of loaded images",
                ),
                io.String.Output(
                    display_name="texts",
                    is_output_list=True,
                    tooltip="List of text captions",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder):
        """Load images and captions from ``folder`` and its direct subfolders.

        Images directly inside the folder are loaded once. Images inside a
        direct subfolder are loaded ``repeat`` times, where ``repeat`` is the
        leading ``<digits>_`` part of the subfolder name (kohya-ss/sd-scripts
        layout) or 1 when the name has no numeric prefix. The caption of an
        image is the content of the file with the same stem and a ``.txt``
        extension; a missing caption file yields an empty string.

        Args:
            folder: Subfolder name, resolved relative to the input directory.

        Returns:
            io.NodeOutput with the list of image tensors and the parallel
            list of caption strings.
        """
        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = (".png", ".jpg", ".jpeg", ".webp")

        image_files = []
        for item in os.listdir(sub_input_dir):
            path = os.path.join(sub_input_dir, item)
            if item.lower().endswith(valid_extensions):
                image_files.append(path)
            elif os.path.isdir(path):
                # Support kohya-ss/sd-scripts folder structure
                repeat = 1
                repeat_prefix = item.split("_")[0]
                if repeat_prefix.isdigit():
                    repeat = int(repeat_prefix)
                image_files.extend(
                    [
                        os.path.join(path, f)
                        for f in os.listdir(path)
                        if f.lower().endswith(valid_extensions)
                    ]
                    * repeat
                )

        captions = []
        for image_file in image_files:
            # Swap only the final extension. str.replace() would rewrite every
            # occurrence of e.g. ".png" in the path, including directory
            # names, and point at a caption file that does not exist.
            caption_file = os.path.splitext(image_file)[0] + ".txt"
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    captions.append(f.read().strip())
            else:
                captions.append("")

        output_tensor = load_and_process_images(image_files, sub_input_dir)

        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
        return io.NodeOutput(output_tensor, captions)
|
|
|
|
|
|
def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True):
    """Save a list of image tensors to disk as PNG files.

    Args:
        image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W])
        output_dir: Directory to save images to (created if it does not exist)
        prefix: Filename prefix
        overwrite: When True, files are named ``{prefix}_{idx:05d}.png``, so a
            later run with the same prefix replaces them. When False, the
            prefix and a counter returned by
            ``folder_paths.get_save_image_path`` are used in the name instead.

    Returns:
        List of saved filenames (file names only, without ``output_dir``)

    Raises:
        ValueError: If an entry of ``image_list`` is not a torch.Tensor.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved_files = []

    for idx, img_tensor in enumerate(image_list):
        if not isinstance(img_tensor, torch.Tensor):
            raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")

        # Remove batch dimension if present [1, H, W, C] -> [H, W, C]
        if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
            img_tensor = img_tensor.squeeze(0)

        # Heuristic: a 3D tensor whose first axis is 1/3/4 and whose other two
        # axes are both larger than 4 is treated as [C, H, W] and permuted to
        # [H, W, C].
        if (
            img_tensor.dim() == 3
            and img_tensor.shape[0] in (1, 3, 4)
            and img_tensor.shape[1] > 4
            and img_tensor.shape[2] > 4
        ):
            img_tensor = img_tensor.permute(1, 2, 0)

        # Convert to numpy and scale to 0-255
        img_array = img_tensor.cpu().numpy()
        img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)

        # Image.fromarray has no mode for an (H, W, 1) array; drop the
        # singleton channel so single-channel input is saved as grayscale.
        if img_array.ndim == 3 and img_array.shape[-1] == 1:
            img_array = img_array[..., 0]

        img = Image.fromarray(img_array)

        if overwrite:
            filename = f"{prefix}_{idx:05d}.png"
        else:
            _, _, counter, _, resolved_prefix = folder_paths.get_save_image_path(prefix, output_dir)
            filename = f"{resolved_prefix}_{counter:05}_{idx:05d}.png"

        img.save(os.path.join(output_dir, filename))
        saved_files.append(filename)

    return saved_files
|
|
|
|
|
|
class SaveImageDataSetToFolderNode(io.ComfyNode):
    """Deprecated output node that writes a list of images into an output subfolder."""

    @classmethod
    def define_schema(cls):
        """Build the schema; every input arrives as a list (is_input_list=True)."""
        node_inputs = [
            io.Image.Input("images", tooltip="List of images to save."),
            io.String.Input(
                "folder_name",
                default="dataset",
                tooltip="Name of the folder to save images to (inside output directory).",
            ),
            io.String.Input(
                "filename_prefix",
                default="image",
                tooltip="Prefix for saved image filenames.",
                advanced=True,
            ),
            io.Combo.Input(
                "mode",
                default="overwrite",
                options=["overwrite", "increment"],
                tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
            ),
        ]
        return io.Schema(
            node_id="SaveImageDataSetToFolder",
            search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"],
            display_name="Save Image (to Folder) (DEPRECATED)",
            category="image",
            description="Save a dataset of images to a specified folder. Supported formats: PNG.",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,  # Receive images as list
            inputs=node_inputs,
            outputs=[],
            # Redundant: the existing Save Image nodes accept a target folder
            # inside filename_prefix.
            is_deprecated=True,
        )

    @classmethod
    def execute(cls, images, folder_name, filename_prefix, mode):
        """Save ``images`` as PNG files under ``<output directory>/<folder_name>``.

        Args:
            images: List of image tensors.
            folder_name: One-element list holding the target folder name.
            filename_prefix: One-element list holding the filename prefix.
            mode: One-element list holding "overwrite" or "increment".

        Returns:
            An empty io.NodeOutput.
        """
        # Widget values are delivered as lists; take the single scalar of each.
        target_folder = folder_name[0]
        name_prefix = filename_prefix[0]
        overwrite_existing = mode[0] == 'overwrite'

        output_dir = os.path.join(folder_paths.get_output_directory(), target_folder)
        saved_files = save_images_to_folder(images, output_dir, name_prefix, overwrite_existing)

        logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
        return io.NodeOutput()
|
|
|
|
|
|
class SaveImageTextDataSetToFolderNode(io.ComfyNode):
    """Output node that writes images as PNG files plus matching ``.txt`` captions."""

    @classmethod
    def define_schema(cls):
        """Build the schema; every input arrives as a list (is_input_list=True)."""
        return io.Schema(
            node_id="SaveImageTextDataSetToFolder",
            search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"],
            display_name="Save Image-Text (to Folder)",
            category="image",
            description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,  # Receive both images and texts as lists
            inputs=[
                io.Image.Input("images", tooltip="List of images to save."),
                io.String.Input("texts",
                    optional=True,
                    force_input=True,
                    tooltip="List of text captions to save."
                ),
                io.String.Input(
                    "folder_name",
                    default="dataset",
                    tooltip="Name of the folder to save images to (inside output directory).",
                ),
                io.String.Input(
                    "filename_prefix",
                    default="image",
                    tooltip="Prefix for saved image filenames.",
                    advanced=True,
                ),
                io.Combo.Input(
                    "mode",
                    default="overwrite",
                    options=["overwrite", "increment"],
                    tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
                ),
            ],
            outputs=[],
        )

    @classmethod
    def execute(cls, images, folder_name, filename_prefix, mode, texts=None):
        """Save images and, when given, one caption file per saved image.

        Args:
            images: List of image tensors.
            folder_name: One-element list holding the target folder name.
            filename_prefix: One-element list holding the filename prefix.
            mode: One-element list holding "overwrite" or "increment".
            texts: Optional list of captions, paired with the images by
                position. Surplus images or captions are left unpaired.

        Returns:
            An empty io.NodeOutput.
        """
        # Extract scalar values
        folder_name = folder_name[0]
        filename_prefix = filename_prefix[0]
        mode = mode[0]

        output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
        saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite')

        # Save captions
        if texts:
            if len(texts) != len(saved_files):
                # zip() below stops at the shorter list; make that visible.
                logging.warning(
                    f"Got {len(texts)} captions for {len(saved_files)} images; "
                    "only the first matching pairs get a caption file."
                )
            for filename, caption in zip(saved_files, texts):
                # Swap only the final extension so a prefix that itself
                # contains ".png" still yields a caption next to its image.
                caption_filename = os.path.splitext(filename)[0] + ".txt"
                caption_path = os.path.join(output_dir, caption_filename)
                with open(caption_path, "w", encoding="utf-8") as f:
                    f.write(caption)

        logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
        return io.NodeOutput()
|
|
|
|
|
|
# ========== Helper Functions for Transform Nodes ==========
|
|
|
|
|
|
def tensor_to_pil(img_tensor):
    """Convert an image tensor to a PIL Image.

    A leading batch dimension of size 1 on a 4D tensor is dropped. Values are
    multiplied by 255, clipped to [0, 255] and cast to uint8 before being
    handed to ``Image.fromarray``.
    """
    has_unit_batch = img_tensor.dim() == 4 and img_tensor.shape[0] == 1
    if has_unit_batch:
        img_tensor = img_tensor.squeeze(0)

    scaled = img_tensor.cpu().numpy() * 255
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
|
|
def pil_to_tensor(img):
    """Convert a PIL Image to a float32 tensor with a leading batch dimension.

    Pixel values are divided by 255; the result has shape ``[1, *array_shape]``.
    """
    pixels = np.array(img).astype(np.float32)
    normalized = pixels / 255.0
    return torch.from_numpy(normalized).unsqueeze(0)
|
|
|
|
|
|
# ========== Base Classes for Transform Nodes ==========
|
|
|
|
|
|
class ImageProcessingNode(io.ComfyNode):
    """Base class for image processing nodes that operate on images.

    Child classes should set:
        node_id: Unique node identifier (required)
        category: Node category passed to the schema (required; this base
            class defines no default)
        search_aliases: List of search aliases (optional)
        display_name: Display name (optional, defaults to node_id)
        description: Node description (optional)
        extra_inputs: List of additional io.Input objects beyond "images" (optional)
        is_group_process: None (auto-detect), True (group), or False (individual) (optional)
        is_output_list: None (auto: list output exactly when group processing),
            True (list output) or False (single output) (optional)
        is_deprecated: True if the node is deprecated (optional, default False)

    Child classes must implement ONE of:
        _process(cls, image, **kwargs) -> tensor (for single-item processing)
        _group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
    """

    node_id = None
    search_aliases = []
    display_name = None
    description = None
    extra_inputs = []
    is_group_process = None  # None = auto-detect, True/False = explicit
    is_output_list = None  # None = auto-detect based on processing mode
    is_deprecated = False

    @classmethod
    def _detect_processing_mode(cls):
        """Detect whether this node uses group or individual processing.

        Returns:
            bool: True if group processing, False if individual processing

        Raises:
            ValueError: If is_group_process is None and the class overrides
                both or neither of _process / _group_process.
        """
        # Explicit setting takes precedence
        if cls.is_group_process is not None:
            return cls.is_group_process

        def overridden(method_name):
            # The first class in the MRO that defines the method in its own
            # __dict__ is the effective definer; it only counts as an
            # override when that class is not this base class.
            definer = next(
                (klass for klass in cls.__mro__ if method_name in klass.__dict__),
                None,
            )
            return definer is not None and definer is not ImageProcessingNode

        has_process = overridden("_process")
        has_group = overridden("_group_process")

        if has_process and has_group:
            raise ValueError(
                f"{cls.__name__}: Cannot override both _process and _group_process. "
                "Override only one, or set is_group_process explicitly."
            )
        if not has_process and not has_group:
            raise ValueError(
                f"{cls.__name__}: Must override either _process or _group_process"
            )

        return has_group

    @classmethod
    def _ensure_image_list(cls, images):
        """Normalize to a flat list of [1, H, W, C] tensors.

        Accepts either one 4D tensor or an iterable of 4D tensors; every
        batch is split along its first dimension.

        Raises:
            ValueError: If a tensor is not 4D or an item is not a tensor.
        """
        if isinstance(images, torch.Tensor):
            if images.ndim != 4:
                raise ValueError(f"Expected 4D image tensor, got shape {tuple(images.shape)}")
            return [images[i:i+1] for i in range(images.shape[0])]

        flat = []
        for item in images:
            if not isinstance(item, torch.Tensor) or item.ndim != 4:
                raise ValueError(f"Expected 4D image tensor, got {type(item).__name__} shape {getattr(item, 'shape', None)}")
            flat.extend([item[i:i+1] for i in range(item.shape[0])])
        return flat

    @classmethod
    def define_schema(cls):
        """Build the io.Schema from the class-level configuration attributes.

        Raises:
            NotImplementedError: If the subclass did not set node_id.
        """
        if cls.node_id is None:
            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

        is_group = cls._detect_processing_mode()

        # Auto-detect is_output_list if not explicitly set
        # Single processing: False (backend collects results into list)
        # Group processing: True by default (can be False for single-output nodes)
        output_is_list = (
            cls.is_output_list if cls.is_output_list is not None else is_group
        )

        inputs = [
            io.Image.Input(
                "images",
                tooltip=(
                    "List of images to process." if is_group else "Image to process."
                ),
            )
        ]
        inputs.extend(cls.extra_inputs)

        return io.Schema(
            node_id=cls.node_id,
            search_aliases=cls.search_aliases,
            display_name=cls.display_name or cls.node_id,
            category=cls.category,
            description=cls.description,
            is_experimental=True,
            # Forward the class-level flag; without this, subclasses that set
            # is_deprecated = True were never marked deprecated in the schema.
            is_deprecated=cls.is_deprecated,
            is_input_list=is_group,  # True for group, False for individual
            inputs=inputs,
            outputs=[
                io.Image.Output(
                    display_name="images",
                    is_output_list=output_is_list,
                    tooltip="Processed images",
                )
            ],
        )

    @classmethod
    def execute(cls, images, **kwargs):
        """Execute the node. Routes to _process or _group_process based on mode."""
        is_group = cls._detect_processing_mode()

        if is_group:
            images = cls._ensure_image_list(images)

        # Extract scalar values from lists for parameters: a one-element list
        # is unwrapped, anything else is passed through unchanged.
        params = {
            k: v[0] if isinstance(v, list) and len(v) == 1 else v
            for k, v in kwargs.items()
        }

        if is_group:
            # Group processing: images is list, call _group_process
            result = cls._group_process(images, **params)
        else:
            # Individual processing: images is single item, call _process
            result = cls._process(images, **params)

        return io.NodeOutput(result)

    @classmethod
    def _process(cls, image, **kwargs):
        """Override this method for single-item processing.

        Args:
            image: tensor - Single image tensor
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            tensor - Processed image
        """
        raise NotImplementedError(f"{cls.__name__} must implement _process method")

    @classmethod
    def _group_process(cls, images, **kwargs):
        """Override this method for group processing.

        Args:
            images: list[tensor] - List of image tensors
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            list[tensor] - Processed images
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement _group_process method"
        )
|
|
|
|
|
|
class TextProcessingNode(io.ComfyNode):
    """Base class for text processing nodes that operate on texts.

    Child classes should set:
        node_id: Unique node identifier (required)
        search_aliases: List of search aliases (optional)
        display_name: Display name (optional, defaults to node_id)
        description: Node description (optional)
        extra_inputs: List of additional io.Input objects beyond "texts" (optional)
        is_group_process: None (auto-detect), True (group), or False (individual) (optional)
        is_output_list: True (list output) or False (single output) (optional;
            the default None is passed to the schema as-is and is treated
            like False by execute)
        is_deprecated: True if the node is deprecated (optional, default False)

    The schema category is always "text".

    Child classes must implement ONE of:
        _process(cls, text, **kwargs) -> str (for single-item processing)
        _group_process(cls, texts, **kwargs) -> list[str] (for group processing)
    """

    node_id = None
    search_aliases = []
    display_name = None
    description = None
    extra_inputs = []
    is_group_process = None  # None = auto-detect, True/False = explicit
    is_output_list = None  # None is falsy: handled like False in execute
    is_deprecated = False

    @classmethod
    def _detect_processing_mode(cls):
        """Detect whether this node uses group or individual processing.

        Returns:
            bool: True if group processing, False if individual processing

        Raises:
            ValueError: If is_group_process is None and the class overrides
                both or neither of _process / _group_process.
        """
        # Explicit setting takes precedence
        if cls.is_group_process is not None:
            return cls.is_group_process

        def overridden(method_name):
            # The first class in the MRO that defines the method in its own
            # __dict__ is the effective definer; it only counts as an
            # override when that class is not this base class.
            definer = next(
                (klass for klass in cls.__mro__ if method_name in klass.__dict__),
                None,
            )
            return definer is not None and definer is not TextProcessingNode

        has_process = overridden("_process")
        has_group = overridden("_group_process")

        if has_process and has_group:
            raise ValueError(
                f"{cls.__name__}: Cannot override both _process and _group_process. "
                "Override only one, or set is_group_process explicitly."
            )
        if not has_process and not has_group:
            raise ValueError(
                f"{cls.__name__}: Must override either _process or _group_process"
            )

        return has_group

    @classmethod
    def define_schema(cls):
        """Build the io.Schema from the class-level configuration attributes.

        Raises:
            NotImplementedError: If the subclass did not set node_id.
        """
        if cls.node_id is None:
            raise NotImplementedError(f"{cls.__name__} must set node_id class variable")

        is_group = cls._detect_processing_mode()

        inputs = [
            io.String.Input(
                "texts",
                tooltip="List of texts to process." if is_group else "Text to process.",
            )
        ]
        inputs.extend(cls.extra_inputs)

        return io.Schema(
            node_id=cls.node_id,
            # search_aliases, description and is_deprecated are documented
            # class-level settings; forward them so subclass values take
            # effect (ImageProcessingNode already forwards the first two).
            search_aliases=cls.search_aliases,
            display_name=cls.display_name or cls.node_id,
            category="text",
            description=cls.description,
            is_experimental=True,
            is_deprecated=cls.is_deprecated,
            is_input_list=is_group,  # True for group, False for individual
            inputs=inputs,
            outputs=[
                io.String.Output(
                    display_name="texts",
                    is_output_list=cls.is_output_list,
                    tooltip="Processed texts",
                )
            ],
        )

    @classmethod
    def execute(cls, texts, **kwargs):
        """Execute the node. Routes to _process or _group_process based on mode."""
        is_group = cls._detect_processing_mode()

        # Extract scalar values from lists for parameters: a one-element list
        # is unwrapped, anything else is passed through unchanged.
        params = {
            k: v[0] if isinstance(v, list) and len(v) == 1 else v
            for k, v in kwargs.items()
        }

        if is_group:
            # Group processing: texts is list, call _group_process
            result = cls._group_process(texts, **params)
        else:
            # Individual processing: texts is single item, call _process
            result = cls._process(texts, **params)

        # Wrap result based on is_output_list
        if cls.is_output_list:
            # Result should already be a list (or will be for individual)
            return io.NodeOutput(result if is_group else [result])
        else:
            # Single output - wrap in list for NodeOutput
            return io.NodeOutput([result])

    @classmethod
    def _process(cls, text, **kwargs):
        """Override this method for single-item processing.

        Args:
            text: str - Single text string
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            str - Processed text
        """
        raise NotImplementedError(f"{cls.__name__} must implement _process method")

    @classmethod
    def _group_process(cls, texts, **kwargs):
        """Override this method for group processing.

        Args:
            texts: list[str] - List of text strings
            **kwargs: Additional parameters (already extracted from lists)

        Returns:
            list[str] - Processed texts
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement _group_process method"
        )
|
|
|
|
|
|
# ========== Image Transform Nodes ==========
|
|
|
|
|
|
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
    """Deprecated node: scale images so the shorter edge equals ``shorter_edge``."""

    node_id = "ResizeImagesByShorterEdge"
    display_name = "Resize Images by Shorter Edge (DEPRECATED)"
    category = "image/transform"
    description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio."
    is_deprecated = True  # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension
    extra_inputs = [
        io.Int.Input(
            "shorter_edge",
            default=512,
            min=1,
            max=8192,
            tooltip="Target dimension for the shorter edge.",
        ),
    ]

    @classmethod
    def _process(cls, image, shorter_edge):
        """Resize every image of the batch with Lanczos resampling.

        Args:
            image: Image tensor. A 4D tensor is treated as a batch and resized
                frame by frame (the same approach as the longer-edge node);
                any other rank is converted as a single image.
            shorter_edge: Target size in pixels for the shorter side.

        Returns:
            Tensor with the resized frames concatenated along dim 0.
        """
        # tensor_to_pil only drops a batch dimension of size 1, so larger
        # batches have to be split here before conversion.
        frames = image if image.dim() == 4 else [image]

        resized = []
        for frame in frames:
            img = tensor_to_pil(frame)
            w, h = img.size
            if w < h:
                new_w = shorter_edge
                new_h = int(h * (shorter_edge / w))
            else:
                new_h = shorter_edge
                new_w = int(w * (shorter_edge / h))
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            resized.append(pil_to_tensor(img))
        return torch.cat(resized, dim=0)
|
|
|
|
|
|
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
    """Deprecated node: scale images so the longer edge equals ``longer_edge``."""

    node_id = "ResizeImagesByLongerEdge"
    display_name = "Resize Images by Longer Edge (DEPRECATED)"
    category = "image/transform"
    description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio."
    is_deprecated = True  # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension
    extra_inputs = [
        io.Int.Input(
            "longer_edge",
            default=1024,
            min=1,
            max=8192,
            tooltip="Target dimension for the longer edge.",
        ),
    ]

    @classmethod
    def _process(cls, image, longer_edge):
        """Resize every image of the batch with Lanczos resampling.

        Args:
            image: Image tensor. A 4D tensor is treated as a batch and resized
                frame by frame; any other rank is converted as a single image
                instead of being iterated along its first axis.
            longer_edge: Target size in pixels for the longer side.

        Returns:
            Tensor with the resized frames concatenated along dim 0.
        """
        frames = image if image.dim() == 4 else [image]

        resized_images = []
        for image_i in frames:
            img = tensor_to_pil(image_i)
            w, h = img.size
            # The shorter side is scaled down proportionally and can truncate
            # to 0 for very elongated images, which resize() rejects; keep it
            # at least 1 pixel.
            if w > h:
                new_w = longer_edge
                new_h = max(1, int(h * (longer_edge / w)))
            else:
                new_h = longer_edge
                new_w = max(1, int(w * (longer_edge / h)))
            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
            resized_images.append(pil_to_tensor(img))
        return torch.cat(resized_images, dim=0)
|
|
|
|
|
|
class CenterCropImagesNode(ImageProcessingNode):
    """Crop the central region of an image to at most ``width`` x ``height``."""

    node_id = "CenterCropImages"
    search_aliases = ["crop", "cut", "trim"]
    display_name = "Crop Image (Center)"
    category = "image/transform"
    description = "Center crop an image to the specified dimensions."
    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
    ]

    @classmethod
    def _process(cls, image, width, height):
        """Center-crop every image of the batch.

        The crop box is clamped to the image bounds, so an image smaller than
        the requested size is returned at its own size along that axis
        (no padding is added).

        Args:
            image: Image tensor. A 4D tensor is treated as a batch and cropped
                frame by frame; any other rank is converted as a single image.
            width: Requested crop width in pixels.
            height: Requested crop height in pixels.

        Returns:
            Tensor with the cropped frames concatenated along dim 0.
        """
        # tensor_to_pil only drops a batch dimension of size 1, so larger
        # batches have to be split here before conversion.
        frames = image if image.dim() == 4 else [image]

        cropped = []
        for frame in frames:
            img = tensor_to_pil(frame)
            left = max(0, (img.width - width) // 2)
            top = max(0, (img.height - height) // 2)
            right = min(img.width, left + width)
            bottom = min(img.height, top + height)
            cropped.append(pil_to_tensor(img.crop((left, top, right, bottom))))
        return torch.cat(cropped, dim=0)
|
|
|
|
|
|
class RandomCropImagesNode(ImageProcessingNode):
    """Crop a seeded random region of an image to at most ``width`` x ``height``."""

    node_id = "RandomCropImages"
    search_aliases = ["crop", "cut", "trim"]
    display_name = "Crop Image (Random)"
    category = "image/transform"
    description = "Randomly crop an image to the specified dimensions."

    extra_inputs = [
        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def _process(cls, image, width, height, seed):
        """Crop every image of the batch with one seeded random offset.

        The offset is drawn once per call (left first, then top) after
        reseeding NumPy's global generator, and reused for all frames of a
        batch. The crop box is clamped to the image bounds, so an image
        smaller than the requested size is not padded.

        Args:
            image: Image tensor. A 4D tensor is treated as a batch and cropped
                frame by frame; any other rank is converted as a single image.
            width: Requested crop width in pixels.
            height: Requested crop height in pixels.
            seed: Random seed; reduced modulo 2**32 - 1 before seeding.

        Returns:
            Tensor with the cropped frames concatenated along dim 0.
        """
        # NOTE: this reseeds NumPy's process-wide generator.
        np.random.seed(seed % (2**32 - 1))

        # tensor_to_pil only drops a batch dimension of size 1, so larger
        # batches have to be split here before conversion.
        frames = image if image.dim() == 4 else [image]

        box = None
        cropped = []
        for frame in frames:
            img = tensor_to_pil(frame)
            if box is None:
                # Frames of one batch tensor share their size, so the box
                # computed from the first frame is valid for all of them.
                max_left = max(0, img.width - width)
                max_top = max(0, img.height - height)
                left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
                top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
                right = min(img.width, left + width)
                bottom = min(img.height, top + height)
                box = (left, top, right, bottom)
            cropped.append(pil_to_tensor(img.crop(box)))
        return torch.cat(cropped, dim=0)
|
|
|
|
|
|
class NormalizeImagesNode(ImageProcessingNode):
    """Shift and scale image values: ``(image - mean) / std``."""

    node_id = "NormalizeImages"
    search_aliases = ["normalize", "normalize colors"]
    display_name = "Normalize Image Colors"
    category = "image/color"
    description = "Normalize images using mean and standard deviation."
    extra_inputs = [
        io.Float.Input(
            "mean",
            default=0.5,
            min=0.0,
            max=1.0,
            tooltip="Mean value for normalization.",
            advanced=True,
        ),
        io.Float.Input(
            "std",
            default=0.5,
            min=0.001,
            max=1.0,
            tooltip="Standard deviation for normalization.",
            advanced=True,
        ),
    ]

    @classmethod
    def _process(cls, image, mean, std):
        """Return the image with ``mean`` subtracted, then divided by ``std``.

        The result is not clamped.
        """
        centered = image - mean
        return centered / std
|
|
|
|
|
|
class AdjustBrightnessNode(ImageProcessingNode):
    """Multiply image values by a brightness factor and clamp to [0, 1]."""

    node_id = "AdjustBrightness"
    search_aliases = ["brightness"]
    display_name = "Adjust Brightness"
    category = "image/adjustments"
    description = "Adjust the brightness of an image."
    extra_inputs = [
        io.Float.Input(
            "factor",
            default=1.0,
            min=0.0,
            max=2.0,
            tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
        ),
    ]

    @classmethod
    def _process(cls, image, factor):
        """Return ``image * factor`` limited to the range [0.0, 1.0]."""
        scaled = image * factor
        return scaled.clamp(0.0, 1.0)
|
|
|
|
|
|
class AdjustContrastNode(ImageProcessingNode):
    """Scale image values around the 0.5 midpoint and clamp to [0, 1]."""

    node_id = "AdjustContrast"
    search_aliases = ["contrast"]
    display_name = "Adjust Contrast"
    category = "image/adjustments"
    description = "Adjust the contrast of an image."
    extra_inputs = [
        io.Float.Input(
            "factor",
            default=1.0,
            min=0.0,
            max=2.0,
            tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
        ),
    ]

    @classmethod
    def _process(cls, image, factor):
        """Return ``(image - 0.5) * factor + 0.5`` limited to [0.0, 1.0]."""
        distance_from_mid = image - 0.5
        stretched = distance_from_mid * factor + 0.5
        return stretched.clamp(0.0, 1.0)
|
|
|
|
|
|
class ShuffleDatasetNode(ImageProcessingNode):
    """Reorder a list of images with a seeded random permutation."""

    node_id = "ShuffleDataset"
    search_aliases = ["shuffle", "randomize", "mix"]
    display_name = "Shuffle Images List"
    category = "image/batch"
    description = "Randomly shuffle the order of images in a list."
    is_group_process = True  # Requires full list to shuffle
    extra_inputs = [
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def _group_process(cls, images, seed):
        """Return the images in an order drawn from NumPy's global generator.

        Args:
            images: List of image tensors.
            seed: Random seed; reduced modulo 2**32 - 1 before seeding.

        Returns:
            New list holding the same items in permuted order.
        """
        # NOTE: this reseeds NumPy's process-wide generator.
        np.random.seed(seed % (2**32 - 1))
        order = np.random.permutation(len(images))

        shuffled = []
        for position in order:
            shuffled.append(images[position])
        return shuffled
|
|
|
|
|
|
class ShuffleImageTextDatasetNode(io.ComfyNode):
    """Special node that shuffles both images and texts together."""

    @classmethod
    def define_schema(cls):
        """Build the schema; every input arrives as a list (is_input_list=True)."""
        node_inputs = [
            io.Image.Input("images", tooltip="List of images to shuffle."),
            io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True),
            io.Int.Input(
                "seed",
                default=0,
                min=0,
                max=0xFFFFFFFFFFFFFFFF,
                tooltip="Random seed.",
            ),
        ]
        node_outputs = [
            io.Image.Output(
                display_name="images",
                is_output_list=True,
                tooltip="Shuffled images",
            ),
            io.String.Output(
                display_name="texts", is_output_list=True, tooltip="Shuffled texts"
            ),
        ]
        return io.Schema(
            node_id="ShuffleImageTextDataset",
            search_aliases=["shuffle", "randomize", "mix"],
            display_name="Shuffle Pairs of Image-Text",
            category="image/batch",
            description="Randomly shuffle the order of pairs of image-text in a list.",
            is_experimental=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, texts, seed):
        """Apply one seeded permutation to both lists so pairs stay aligned.

        Args:
            images: List of image tensors.
            texts: List of captions, indexed with the same permutation as
                ``images`` (the permutation length is ``len(images)``).
            seed: One-element list holding the random seed.

        Returns:
            io.NodeOutput with the shuffled images and the shuffled texts.
        """
        seed_value = seed[0]  # Extract scalar
        # NOTE: this reseeds NumPy's process-wide generator.
        np.random.seed(seed_value % (2**32 - 1))
        order = np.random.permutation(len(images))

        shuffled_images = []
        for position in order:
            shuffled_images.append(images[position])
        shuffled_texts = []
        for position in order:
            shuffled_texts.append(texts[position])

        return io.NodeOutput(shuffled_images, shuffled_texts)
|
|
|
|
|
|
# ========== Text Transform Nodes ==========
|
|
|
|
|
|
class TextToLowercaseNode(TextProcessingNode):
    """Deprecated node that lowercases a text."""

    node_id = "TextToLowercase"
    search_aliases = ["lowercase"]
    display_name = "Convert Text to Lowercase (DEPRECATED)"
    category = "text"
    description = "Convert text to lowercase."
    is_deprecated = True  # This node is superseded by the Convert Text Case node

    @classmethod
    def _process(cls, text):
        """Return ``text`` converted with ``lower()``."""
        lowered = text.lower()
        return lowered
|
|
|
|
|
|
class TextToUppercaseNode(TextProcessingNode):
    """Deprecated node that uppercases a text."""

    node_id = "TextToUppercase"
    search_aliases = ["uppercase"]
    display_name = "Convert Text to Uppercase (DEPRECATED)"
    category = "text"
    description = "Convert text to uppercase."
    is_deprecated = True  # This node is superseded by the Convert Text Case node

    @classmethod
    def _process(cls, text):
        """Return ``text`` converted with ``upper()``."""
        uppered = text.upper()
        return uppered
|
|
|
|
|
|
class TruncateTextNode(TextProcessingNode):
    """Node that keeps only the first ``max_length`` characters of a text."""

    node_id = "TruncateText"
    search_aliases = ["truncate", "cut", "shorten"]
    display_name = "Truncate Text"
    category = "text"
    description = "Truncate text to a maximum length."
    extra_inputs = [
        io.Int.Input(
            "max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
        ),
    ]

    @classmethod
    def _process(cls, text, max_length):
        """Return the leading slice of ``text`` up to ``max_length`` items."""
        head = text[0:max_length]
        return head
|
|
|
|
|
|
class AddTextPrefixNode(TextProcessingNode):
    """Deprecated node that prepends a fixed prefix to a text."""

    node_id = "AddTextPrefix"
    display_name = "Add Text Prefix (DEPRECATED)"
    category = "text"
    description = "Add a prefix to all texts."
    is_deprecated = True  # This node is superseded by the Concatenate Text node
    extra_inputs = [
        io.String.Input("prefix", default="", tooltip="Prefix to add."),
    ]

    @classmethod
    def _process(cls, text, prefix):
        """Return ``prefix`` followed by ``text``."""
        combined = prefix + text
        return combined
|
|
|
|
|
|
class AddTextSuffixNode(TextProcessingNode):
    """Deprecated text node that appends a fixed suffix.

    Marked deprecated in favour of the Concatenate Text node.
    """

    node_id = "AddTextSuffix"
    display_name = "Add Text Suffix (DEPRECATED)"
    category = "text"
    description = "Add a suffix to all texts."
    is_deprecated = True  # Superseded by the Concatenate Text node
    extra_inputs = [
        io.String.Input(
            "suffix",
            default="",
            tooltip="Suffix to add.",
        ),
    ]

    @classmethod
    def _process(cls, text, suffix):
        """Return ``text`` followed by ``suffix``."""
        combined = text + suffix
        return combined
|
|
|
|
|
|
class ReplaceTextNode(TextProcessingNode):
    """Deprecated text node that replaces every occurrence of a substring.

    Marked deprecated in favour of the other Replace Text node.
    """

    node_id = "ReplaceText"
    display_name = "Replace Text (DEPRECATED)"
    category = "text"
    description = "Replace text in all texts."
    is_deprecated = True  # This node is superseded by the other Replace Text node
    extra_inputs = [
        io.String.Input("find", default="", tooltip="Text to find."),
        io.String.Input("replace", default="", tooltip="Text to replace with."),
    ]

    @classmethod
    def _process(cls, text, find, replace):
        """Return ``text`` with every occurrence of ``find`` replaced.

        Args:
            text: The text to process.
            find: Substring to search for. An empty string (the input's
                default) leaves ``text`` unchanged.
            replace: Replacement substring.

        Returns:
            The processed text.
        """
        # str.replace("", x) inserts x before every character and at the end,
        # which is never what "nothing to find" should mean.
        if not find:
            return text
        return text.replace(find, replace)
|
|
|
|
|
|
class StripWhitespaceNode(TextProcessingNode):
    """Deprecated text node that trims surrounding whitespace.

    Marked deprecated in favour of the Trim Text node.
    """

    node_id = "StripWhitespace"
    display_name = "Strip Whitespace (DEPRECATED)"
    category = "text"
    description = "Strip leading and trailing whitespace from all texts."
    is_deprecated = True  # Superseded by the Trim Text node

    @classmethod
    def _process(cls, text):
        """Return ``text`` with ``str.strip`` applied."""
        trimmed = text.strip()
        return trimmed
|
|
|
|
|
|
# ========== Group Processing Example Nodes ==========
|
|
|
|
|
|
class ImageDeduplicationNode(ImageProcessingNode):
    """Drop images whose 64-bit average hash is too close to an earlier kept image."""

    node_id = "ImageDeduplication"
    search_aliases = ["deduplicate", "remove duplicates", "similarity filter"]
    display_name = "Deduplicate Images"
    category = "image/batch"
    description = "Remove duplicate or very similar images from a list."
    is_group_process = True  # Needs the whole list to compare images pairwise
    extra_inputs = [
        io.Float.Input(
            "similarity_threshold",
            default=0.95,
            min=0.0,
            max=1.0,
            tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
            advanced=True,
        ),
    ]

    @classmethod
    def _group_process(cls, images, similarity_threshold):
        """Return the images that are not near-duplicates of an earlier kept image.

        Each image is reduced to an 8x8 greyscale thumbnail and hashed to a
        64-character bit string (one bit per pixel: above the mean or not).
        Similarity is ``1 - hamming_distance / 64``; an image is skipped when
        its similarity to any already-kept image reaches the threshold.
        """
        if len(images) == 0:
            return []

        def average_hash(img_tensor):
            """Hash one image as a 64-character string of "0"/"1" bits."""
            thumb = tensor_to_pil(img_tensor).resize((8, 8), Image.Resampling.LANCZOS).convert("L")
            values = list(thumb.getdata())
            mean = sum(values) / len(values)
            return "".join("1" if v > mean else "0" for v in values)

        hashes = [average_hash(img) for img in images]

        kept = []
        for idx, candidate in enumerate(hashes):
            duplicate_found = False
            # Only compare against images that were kept, in keep order.
            for kept_idx in kept:
                differing = sum(a != b for a, b in zip(candidate, hashes[kept_idx]))
                similarity = 1.0 - (differing / 64.0)  # 64 bits total
                if similarity >= similarity_threshold:
                    duplicate_found = True
                    logging.info(
                        f"Image {idx} is similar to image {kept_idx} (similarity: {similarity:.3f}), skipping"
                    )
                    break
            if not duplicate_found:
                kept.append(idx)

        unique_images = [images[i] for i in kept]
        logging.info(
            f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
        )
        return unique_images
|
|
|
|
|
|
class ImageGridNode(ImageProcessingNode):
    """Paste a list of images into one fixed-cell grid image."""

    node_id = "ImageGrid"
    search_aliases = ["grid", "collage", "combine"]
    display_name = "Make Image Grid"
    category = "image/batch"
    description = "Arrange multiple images into a grid layout."
    is_group_process = True  # Needs the whole list to lay out the grid
    is_output_list = False  # Produces a single grid image
    extra_inputs = [
        io.Int.Input(
            "columns",
            default=4,
            min=1,
            max=20,
            tooltip="Number of columns in the grid.",
        ),
        io.Int.Input(
            "cell_width",
            default=256,
            min=32,
            max=2048,
            tooltip="Width of each cell in the grid.",
            advanced=True,
        ),
        io.Int.Input(
            "cell_height",
            default=256,
            min=32,
            max=2048,
            tooltip="Height of each cell in the grid.",
            advanced=True,
        ),
        io.Int.Input(
            "padding",
            default=4,
            min=0,
            max=50,
            tooltip="Padding between images.",
            advanced=True,
        ),
    ]

    @classmethod
    def _group_process(cls, images, columns, cell_width, cell_height, padding):
        """Resize every image to the cell size and paste it row-major onto a black canvas.

        Raises:
            ValueError: If ``images`` is empty.
        """
        num_images = len(images)
        if num_images == 0:
            raise ValueError("Cannot create grid from empty image list")

        # Ceiling division: the last row may be only partly filled.
        rows = -(-num_images // columns)

        # Padding is applied between cells only, not around the border.
        grid_width = columns * cell_width + (columns - 1) * padding
        grid_height = rows * cell_height + (rows - 1) * padding
        grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))

        cell_size = (cell_width, cell_height)
        x_step = cell_width + padding
        y_step = cell_height + padding
        for idx, img_tensor in enumerate(images):
            row, col = divmod(idx, columns)
            tile = tensor_to_pil(img_tensor).resize(cell_size, Image.Resampling.LANCZOS)
            grid.paste(tile, (col * x_step, row * y_step))

        logging.info(
            f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
        )
        return pil_to_tensor(grid)
|
|
|
|
|
|
class MergeImageListsNode(ImageProcessingNode):
    """Deprecated pass-through node that outputs the image list it receives.

    Marked deprecated in favour of the Create List node.
    """

    node_id = "MergeImageLists"
    search_aliases = ["list", "merge list", "make list"]
    display_name = "Merge Image Lists (DEPRECATED)"
    category = "image/batch"
    description = "Concatenate multiple image lists into one."
    is_group_process = True  # Receives the images as one list
    is_deprecated = True  # Superseded by the Create List node

    @classmethod
    def _group_process(cls, images):
        """Log the list length and return ``images`` unchanged.

        No merging happens here; per the original comments, connected list
        inputs are already concatenated by input handling.
        """
        count = len(images)
        logging.info(f"Merged image list contains {count} images")
        return images
|
|
|
|
|
|
class MergeTextListsNode(TextProcessingNode):
    """Deprecated pass-through node that outputs the text list it receives.

    Marked deprecated in favour of the Create List node.
    """

    node_id = "MergeTextLists"
    display_name = "Merge Text Lists (DEPRECATED)"
    category = "text"
    description = "Concatenate multiple text lists into one."
    is_group_process = True  # Receives the texts as one list
    is_deprecated = True  # Superseded by the Create List node

    @classmethod
    def _group_process(cls, texts):
        """Log the list length and return ``texts`` unchanged.

        No merging happens here; per the original comments, connected list
        inputs are already concatenated by input handling.
        """
        count = len(texts)
        logging.info(f"Merged text list contains {count} texts")
        return texts
|
|
|
|
|
|
# ========== Training Dataset Nodes ==========
|
|
|
|
|
|
# Sentinel key used in the "skeleton" to mark where a tensor lived in the
# original nested structure. The skeleton is pickled; tensors live in the
# accompanying safetensors file under the referenced key.
_TREF_KEY = "__tref__"


def _split_tensors(obj, out_tensors, prefix):
    """Separate the tensors of a nested structure from its "skeleton".

    Walks ``obj`` recursively through dicts, lists and tuples. Every
    ``torch.Tensor`` is copied into ``out_tensors`` and replaced in the
    returned skeleton by ``{"__tref__": key}``. Anything else (Hook objects,
    floats, strings, custom extension types, ...) passes through untouched and
    is left for pickle to handle.

    Args:
        obj: The structure to split.
        out_tensors: Dict that receives the extracted tensors; mutated in
            place. Keys are ``f"{prefix}_{N}"`` where N is the dict's size at
            insertion time, so a dict shared across calls keeps keys unique.
        prefix: Key prefix for tensors extracted from this structure.

    Returns:
        The skeleton: same nesting as ``obj`` with tensors replaced by markers.
    """
    if isinstance(obj, torch.Tensor):
        key = f"{prefix}_{len(out_tensors)}"
        # Force a contiguous private copy: a plain clone() preserves the
        # strides of dense non-contiguous inputs (e.g. transposed tensors),
        # and safetensors refuses to save non-contiguous tensors.
        out_tensors[key] = obj.detach().cpu().clone(memory_format=torch.contiguous_format)
        return {_TREF_KEY: key}
    if isinstance(obj, dict):
        return {k: _split_tensors(v, out_tensors, prefix) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_split_tensors(v, out_tensors, prefix) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_split_tensors(v, out_tensors, prefix) for v in obj)
    return obj
|
|
|
|
|
|
def _rejoin_tensors(obj, tensor_getter):
    """Rebuild a nested structure from a skeleton (inverse of _split_tensors).

    Dicts, lists and tuples are walked recursively. A dict that consists of
    exactly the single ``{"__tref__": key}`` marker is replaced by
    ``tensor_getter(key)``; every other value is returned unchanged.
    """
    if isinstance(obj, dict):
        # A marker is a one-entry dict; larger dicts are ordinary containers.
        if len(obj) == 1 and _TREF_KEY in obj:
            return tensor_getter(obj[_TREF_KEY])
        rebuilt = {}
        for name, value in obj.items():
            rebuilt[name] = _rejoin_tensors(value, tensor_getter)
        return rebuilt
    if isinstance(obj, list):
        return [_rejoin_tensors(item, tensor_getter) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_rejoin_tensors(item, tensor_getter) for item in obj)
    return obj
|
|
|
|
|
|
# safetensors dtype strings -> torch dtype, used to read shapes/dtypes from the
# header without loading any tensor bytes.
_ST_STR_TO_DTYPE = {
    "F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
    "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32,
    "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
}
# Dtypes that only some torch builds provide (float8, wider unsigned ints).
# They are registered when available so header lookups for such tensors do not
# fail with KeyError; on older builds they are simply left out.
_ST_STR_TO_DTYPE.update({
    st_name: getattr(torch, torch_name)
    for st_name, torch_name in (
        ("F8_E4M3", "float8_e4m3fn"),
        ("F8_E5M2", "float8_e5m2"),
        ("U16", "uint16"),
        ("U32", "uint32"),
        ("U64", "uint64"),
    )
    if hasattr(torch, torch_name)
})
|
|
|
|
|
|
def _read_safetensors_header(path):
    """Return the JSON header of a safetensors file as a dict.

    Only the header is read: an 8-byte little-endian length followed by that
    many bytes of JSON. The ``"__metadata__"`` entry, if present, is removed
    from the returned dict.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        raw_header = fh.read(header_len)
    entries = json.loads(raw_header)
    entries.pop("__metadata__", None)
    return entries
|
|
|
|
|
|
class RealizeRequired(RuntimeError):
    """Error for using lazy on-disk dataset data where real tensors are needed.

    The data must be realized first: call .realize() in code, or use the
    Realize Lazy Latents / Realize Lazy Conditionings nodes in a workflow.
    """
|
|
|
|
|
|
def _need_realize(self, *args, **kwargs):
    """Always raise RealizeRequired, naming the type of ``self``.

    Accepts and ignores any extra arguments so it can stand in for any
    method or operator of a lazy class.
    """
    owner = type(self).__name__
    raise RealizeRequired(
        f"{owner} is lazy on-disk data and does not support this "
        "operation. Realize it first (.realize() or a Realize node)."
    )
|
|
|
|
|
|
class LazyTensorInfo:
    """Shape and dtype of one on-disk tensor, holding no tensor bytes.

    Exposes .shape (a torch.Size), .dtype and .ndim. Looking up any other
    attribute raises RealizeRequired.
    """

    def __init__(self, shape, dtype):
        size = torch.Size(shape)
        self.shape = size
        self.dtype = dtype
        self.ndim = len(size)

    def __repr__(self):
        shape_text = tuple(self.shape)
        return f"LazyTensorInfo(shape={shape_text}, dtype={self.dtype})"

    # Called only for attributes not found normally, i.e. anything beyond
    # shape/dtype/ndim and the methods defined on the class.
    __getattr__ = _need_realize
|
|
|
|
|
|
class LazyLatent:
    """Handle for one dataset sample's latent dict ({"samples": tensor, ...}) on disk.

    Holds the sample's skeleton and a reader. Indexing a tensor entry
    (e.g. latent["samples"]) returns a LazyTensorInfo built from the reader's
    shape/dtype lookups; non-tensor entries are returned as stored. Tensor
    values require realization: realize() returns the real latent dict and
    realize_samples() returns the real "samples" tensor. Nothing is cached on
    the handle, so each realize call asks the reader again.
    """

    def __init__(self, reader, skeleton):
        self._reader = reader
        self._skel = skeleton

    def __getitem__(self, name):
        entry = self._skel[name]
        is_tensor_ref = isinstance(entry, dict) and len(entry) == 1 and _TREF_KEY in entry
        if not is_tensor_ref:
            return entry  # plain non-tensor value (e.g. batch_index)
        key = entry[_TREF_KEY]
        reader = self._reader
        return LazyTensorInfo(reader.shape(key), reader.dtype(key))

    def realize(self):
        """Fetch this sample's tensors through the reader; return the real latent dict."""
        fetch = self._reader.get_tensor
        return _rejoin_tensors(self._skel, fetch)

    def realize_samples(self):
        """Fetch and return only the real "samples" tensor."""
        samples_key = self._skel["samples"][_TREF_KEY]
        return self._reader.get_tensor(samples_key)

    def __repr__(self):
        samples_info = self["samples"]
        shape_text = tuple(samples_info.shape)
        return f"LazyLatent(samples={shape_text}, dtype={samples_info.dtype})"
|
|
|
|
|
|
class LazyConditioning:
    """Handle for one dataset sample's conditioning on disk.

    The content is an arbitrary pickled structure, so the only access is
    realize(), which returns the list of [tensor, dict] entries.
    """

    def __init__(self, reader, skeleton):
        self._reader = reader
        self._skel = skeleton

    def realize(self):
        """Fetch this sample's full conditioning through the reader."""
        fetch = self._reader.get_tensor
        return _rejoin_tensors(self._skel, fetch)

    # A realized conditioning already is its entry list, so both names
    # refer to the same function.
    realize_entries = realize

    def __repr__(self):
        return "LazyConditioning(on-disk)"
|
|
|
|
|
|
class LazyCondEntry:
    """A single entry of a LazyConditioning, selected by index.

    Emitted by ResolutionBucket so that each bucket row pairs with exactly
    one conditioning entry.
    """

    def __init__(self, lazy_cond, index):
        self._cond = lazy_cond
        self._index = index

    def realize(self):
        """Realize the parent conditioning and return the entry at this index."""
        entries = self._cond.realize()
        return entries[self._index]

    def realize_entries(self):
        """Return the realized entry wrapped in a one-element list."""
        entry = self.realize()
        return [entry]

    def __repr__(self):
        return f"LazyCondEntry(index={self._index})"
|
|
|
|
|
|
class LazyBatchSamples:
    """The "samples" batch of one resolution bucket, realized on demand.

    Rows are either LazyLatent handles (whose stored "samples" carry a leading
    dimension that is dropped, i.e. stored as (1, *row_shape)) or already-real
    row tensors when eager and lazy inputs are mixed. .shape, .dtype and .ndim
    are derived from the first row only. realize_rows(indices) reads just the
    selected rows, which is the per-training-step read unit.
    """

    def __init__(self, rows):
        self.rows = list(rows)
        head = self.rows[0]
        if isinstance(head, LazyLatent):
            head_info = head["samples"]
            # Drop the stored leading batch dimension.
            row_shape = tuple(head_info.shape[1:])
            self.dtype = head_info.dtype
        else:
            row_shape = tuple(head.shape)
            self.dtype = head.dtype
        self.shape = torch.Size((len(self.rows), *row_shape))
        self.ndim = len(self.shape)

    def _row(self, i):
        """Return row ``i`` as a real tensor, reading it if it is lazy."""
        row = self.rows[int(i)]
        if isinstance(row, LazyLatent):
            return row.realize_samples()[0]
        return row

    def realize_rows(self, indices):
        """Read only the selected rows; return them stacked (len(indices), *row_shape)."""
        picked = [self._row(i) for i in indices]
        return torch.stack(picked, dim=0)

    def realize(self):
        """Read and stack all rows: (N, *row_shape)."""
        every_index = range(len(self.rows))
        return self.realize_rows(every_index)

    def __repr__(self):
        shape_text = tuple(self.shape)
        return f"LazyBatchSamples(shape={shape_text}, dtype={self.dtype})"
|
|
|
|
|
|
# Handle classes that stand in for on-disk dataset data; used with isinstance
# to recognise values that must be realized before real tensors are available.
_LAZY_DATASET_TYPES = (LazyLatent, LazyConditioning, LazyCondEntry, LazyBatchSamples)

# Any op a lazy class doesn't define itself (indexing, iteration, math,
# truthiness, pickling) raises RealizeRequired instead of silently misbehaving.
# Only names missing from the class's own __dict__ are patched, so a dunder the
# class defines itself (e.g. LazyLatent.__getitem__) is left in place.
# Note: the loop variables _cls and _op remain bound at module level afterwards.
for _cls in (LazyTensorInfo, *_LAZY_DATASET_TYPES):
    for _op in ("__getitem__", "__iter__", "__len__", "__bool__", "__reduce__",
                "__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__",
                "__truediv__", "__matmul__", "__neg__"):
        if _op not in _cls.__dict__:
            setattr(_cls, _op, _need_realize)
|
|
|
|
|
|
def _realize_structure(obj):
    """Return ``obj`` with every lazy dataset object replaced by its realized value.

    Lazy handles are replaced by the result of their realize(); dicts, lists
    and tuples are rebuilt recursively; real tensors and plain values are
    returned unchanged.
    """
    if isinstance(obj, _LAZY_DATASET_TYPES):
        return obj.realize()
    if isinstance(obj, dict):
        realized = {}
        for name, value in obj.items():
            realized[name] = _realize_structure(value)
        return realized
    if isinstance(obj, list):
        return [_realize_structure(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_realize_structure(item) for item in obj)
    return obj
|
|
|
|
|
|
class _ShardReader:
    """Random-access reader for one dataset shard.

    The skeleton pickle is loaded eagerly in the constructor. The safetensors
    file is opened on first tensor access and the handle is then reused, so
    read_sample(i) fetches only the tensors referenced by sample i.
    read_sample_lazy(i) fetches nothing: it returns (LazyLatent,
    LazyConditioning) handles that read on demand.
    """

    def __init__(self, shard_path, skeleton_path):
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file, so skeleton files should only come from trusted datasets.
        with open(skeleton_path, "rb") as skeleton_file:
            self.skeletons = pickle.load(skeleton_file)
        self.shard_path = shard_path
        self._st = None  # safe_open handle, created on first use
        self._header = None  # parsed safetensors header, read on first use

    def _open(self):
        """Return the safe_open handle, opening the shard on first call."""
        handle = self._st
        if handle is None:
            handle = self._st = safe_open(self.shard_path, framework="pt")
        return handle

    @property
    def header(self):
        """The shard's safetensors header dict, read once and then reused."""
        if self._header is None:
            self._header = _read_safetensors_header(self.shard_path)
        return self._header

    def shape(self, key):
        """Return the shape of tensor ``key`` as a tuple, taken from the header."""
        entry = self.header[key]
        return tuple(entry["shape"])

    def dtype(self, key):
        """Return the torch dtype of tensor ``key``, taken from the header."""
        dtype_name = self.header[key]["dtype"]
        return _ST_STR_TO_DTYPE[dtype_name]

    def get_tensor(self, key):
        """Read and return tensor ``key`` from the shard."""
        handle = self._open()
        return handle.get_tensor(key)

    def get_slice(self, key):
        """Return the safe_open slice object for tensor ``key``."""
        handle = self._open()
        return handle.get_slice(key)

    def __len__(self):
        return len(self.skeletons)

    def read_sample(self, local_idx):
        """Return (latent_dict, conditioning_list) for one sample, reading its
        tensors eagerly."""
        latent_skel, cond_skel = self.skeletons[local_idx]
        fetch = self._open().get_tensor
        latent = _rejoin_tensors(latent_skel, fetch)
        cond = _rejoin_tensors(cond_skel, fetch)
        return latent, cond

    def read_sample_lazy(self, local_idx):
        """Return (LazyLatent, LazyConditioning) handles for one sample.

        No tensor bytes are read here. The handles carry the sample's
        skeleton and this reader, so shape/dtype queries go through the
        header and realize() reads only this sample's tensors.
        """
        latent_skel, cond_skel = self.skeletons[local_idx]
        lazy_latent = LazyLatent(self, latent_skel)
        lazy_cond = LazyConditioning(self, cond_skel)
        return lazy_latent, lazy_cond
|
|
|
|
|
|
class ResolutionBucket(io.ComfyNode):
    """Group latent rows and their conditioning entries by latent resolution.

    Rows are keyed by the last two dimensions of their "samples" shape, so
    each output batch holds rows of one (H, W).
    """

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Latent.Input(
                "latents",
                tooltip="List of latent dicts to bucket by resolution.",
            ),
            io.Conditioning.Input(
                "conditioning",
                tooltip="List of conditioning lists (must match latents length).",
            ),
        ]
        node_outputs = [
            io.Latent.Output(
                display_name="latents",
                is_output_list=True,
                tooltip="List of batched latent dicts, one per resolution bucket.",
            ),
            io.Conditioning.Output(
                display_name="conditioning",
                is_output_list=True,
                tooltip="List of condition lists, one per resolution bucket.",
            ),
        ]
        return io.Schema(
            node_id="ResolutionBucket",
            search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
            display_name="Resolution Bucket",
            category="model/training",
            description="Group latents and conditionings into buckets",
            is_experimental=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, latents, conditioning):
        """Bucket rows by (H, W) and emit one batch plus one entry list per bucket.

        Args:
            latents: List of latent dicts {"samples": tensor} and/or LazyLatent
                handles. The first dimension of "samples" is treated as the
                batch and the last two as (H, W).
            conditioning: List, same length as ``latents``, of conditioning
                lists (indexed once per batch row) and/or LazyConditioning.

        Returns:
            io.NodeOutput with a list of {"samples": batch} dicts and a
            parallel list of conditioning-entry lists. A bucket containing any
            lazy row gets a LazyBatchSamples batch; otherwise a stacked tensor.

        Raises:
            ValueError: If the two input lists differ in length.
            RealizeRequired: If a lazy latent's stored batch size is not 1.
        """
        if len(latents) != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
            )

        def entry_for(cond, index):
            # Lazy conditioning is not read here; defer the pick to realize time.
            if isinstance(cond, LazyConditioning):
                return LazyCondEntry(cond, index)
            return cond[index]

        # (h, w) -> (rows, conds); the two lists stay index-aligned.
        buckets = {}
        any_lazy = False

        for latent, cond in zip(latents, conditioning):
            if isinstance(latent, LazyLatent):
                # Grouped from shape metadata only; no tensor bytes are read.
                info = latent["samples"]
                if int(info.shape[0]) != 1:
                    raise RealizeRequired(
                        "ResolutionBucket: lazy latents with stored batch size > 1 "
                        "are not supported; insert a Realize Lazy Latents node first."
                    )
                any_lazy = True
                resolution = (int(info.shape[-2]), int(info.shape[-1]))
                rows, conds = buckets.setdefault(resolution, ([], []))
                rows.append(latent)
                conds.append(entry_for(cond, 0))
                continue

            samples = latent["samples"]  # real tensor; first dim is the batch
            resolution = (int(samples.shape[-2]), int(samples.shape[-1]))
            rows, conds = buckets.setdefault(resolution, ([], []))
            # One conditioning entry is taken per batch row.
            for row_idx in range(samples.shape[0]):
                rows.append(samples[row_idx])
                conds.append(entry_for(cond, row_idx))

        output_latents = []  # list[{"samples": (Bi, *row_shape)}]
        output_conditions = []  # list[list[cond entry]] with Bi entries each
        total = 0

        for (h, w), (rows, conds) in buckets.items():
            total += len(rows)
            has_lazy_row = any(isinstance(r, LazyLatent) for r in rows)
            if has_lazy_row:
                samples = LazyBatchSamples(rows)
            else:
                samples = torch.stack(rows, dim=0)
            output_latents.append({"samples": samples})
            output_conditions.append(conds)
            logging.info(f"Resolution bucket ({h}x{w}): {len(rows)} samples")

        logging.info(
            f"Created {len(buckets)} resolution buckets from {total} samples "
            f"({'lazy' if any_lazy else 'eager'})"
        )
        return io.NodeOutput(output_latents, output_conditions)
|
|
|
|
|
|
class MakeTrainingDataset(io.ComfyNode):
    """Build a training dataset by VAE-encoding images and CLIP-encoding captions."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Image.Input("images", tooltip="List of images to encode."),
            io.Vae.Input(
                "vae",
                tooltip="VAE model for encoding images to latents.",
            ),
            io.Clip.Input(
                "clip",
                tooltip="CLIP model for encoding text to conditioning.",
            ),
            io.String.Input(
                "texts",
                optional=True,
                tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
                force_input=True,
            ),
        ]
        node_outputs = [
            io.Latent.Output(
                display_name="latents",
                is_output_list=True,
                tooltip="List of latent dicts",
            ),
            io.Conditioning.Output(
                display_name="conditioning",
                is_output_list=True,
                tooltip="List of conditioning lists",
            ),
        ]
        return io.Schema(
            node_id="MakeTrainingDataset",
            search_aliases=["encode dataset"],
            display_name="Make Training Dataset",
            category="model/training",
            description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
            is_experimental=True,
            is_input_list=True,  # every input arrives wrapped in a list
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, vae, clip, texts=None):
        """Encode each image and caption; return parallel latent and conditioning lists.

        Args:
            images: List of image tensors; only the first three channels of
                the last dimension are passed to the VAE.
            vae: One-element list holding the VAE.
            clip: One-element list holding the CLIP model.
            texts: Optional caption list. Missing or empty means a single
                empty caption; one caption is repeated for every image.

        Raises:
            ValueError: If the caption count is neither 1 nor the image count.
        """
        # Scalars arrive wrapped in one-element lists.
        vae = vae[0]
        clip = clip[0]

        num_images = len(images)

        if texts is None or len(texts) == 0:
            # No captions: use a single empty caption (unconditional training).
            texts = [""]

        if len(texts) == 1 and num_images > 1:
            texts = texts * num_images  # repeat the single caption per image
        elif len(texts) != num_images:
            raise ValueError(
                f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
                f"Text list should have length {num_images}, 1, or 0."
            )

        logging.info(f"Encoding {num_images} images with VAE...")
        # Per the original comments each image tensor is [1, H, W, 3].
        latents_list = [
            {"samples": vae.encode(img_tensor[:, :, :, :3])} for img_tensor in images
        ]

        logging.info(f"Encoding {len(texts)} texts with CLIP...")
        # Empty and non-empty captions go through the same tokenize/encode calls.
        conditioning_list = [
            clip.encode_from_tokens_scheduled(clip.tokenize(text)) for text in texts
        ]

        logging.info(
            f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
        )
        return io.NodeOutput(latents_list, conditioning_list)
|
|
|
|
|
|
class SaveTrainingDataset(io.ComfyNode):
    """Write an encoded training dataset (latents + conditioning) to disk in shards."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Latent.Input(
                "latents",
                tooltip="List of latent dicts from MakeTrainingDataset.",
            ),
            io.Conditioning.Input(
                "conditioning",
                tooltip="List of conditioning lists from MakeTrainingDataset.",
            ),
            io.String.Input(
                "folder_name",
                default="training_dataset",
                tooltip="Name of folder to save dataset (inside output directory).",
            ),
            io.Int.Input(
                "shard_size",
                default=1000,
                min=1,
                max=100000,
                tooltip="Number of samples per shard file.",
                advanced=True,
            ),
        ]
        return io.Schema(
            node_id="SaveTrainingDataset",
            search_aliases=["export dataset", "save dataset"],
            display_name="Save Training Dataset",
            category="model/training",
            description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
            is_experimental=True,
            is_output_node=True,
            is_input_list=True,  # every input arrives wrapped in a list
            inputs=node_inputs,
            outputs=[],
        )

    @classmethod
    def execute(cls, latents, conditioning, folder_name, shard_size):
        """Save the samples as shard files plus a metadata.json summary.

        Each shard is written as ``shard_NNNN.safetensors`` (all tensors) and
        ``shard_NNNN.skeleton.pkl`` (the pickled nested structure with tensor
        markers). Lazy inputs are realized one sample at a time.

        Args:
            latents: List of latent dicts (or lazy handles).
            conditioning: List of conditioning lists (or lazy handles); their
                dicts may hold arbitrary picklable extension objects.
            folder_name: One-element list with the target folder name.
            shard_size: One-element list with the number of samples per shard.

        Raises:
            ValueError: If ``latents`` and ``conditioning`` differ in length.
        """
        # Scalars arrive wrapped in one-element lists.
        folder_name = folder_name[0]
        shard_size = shard_size[0]

        num_samples = len(latents)
        if num_samples != len(conditioning):
            raise ValueError(
                f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
                f"Something went wrong in dataset preparation."
            )

        # TODO (from the original): folder_name is joined unchecked, so the
        # dataset can currently be written anywhere; this still needs resolving.
        output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
        os.makedirs(output_dir, exist_ok=True)

        num_shards = (num_samples + shard_size - 1) // shard_size  # ceiling division

        logging.info(
            f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
        )

        for shard_idx in range(num_shards):
            start_idx = shard_idx * shard_size
            end_idx = min(start_idx + shard_size, num_samples)

            # One dict of tensors and one skeleton list per shard, so at most
            # one shard's tensors are held at a time.
            shard_tensors = {}
            shard_skeletons = []  # (latent_skeleton, cond_skeleton) per sample

            for local_idx in range(end_idx - start_idx):
                sample_idx = start_idx + local_idx
                latent_skel = _split_tensors(
                    _realize_structure(latents[sample_idx]),
                    shard_tensors,
                    f"s{local_idx}_lat",
                )
                cond_skel = _split_tensors(
                    _realize_structure(conditioning[sample_idx]),
                    shard_tensors,
                    f"s{local_idx}_cond",
                )
                shard_skeletons.append((latent_skel, cond_skel))

            shard_stem = f"shard_{shard_idx:04d}"
            shard_path = os.path.join(output_dir, f"{shard_stem}.safetensors")
            skeleton_path = os.path.join(output_dir, f"{shard_stem}.skeleton.pkl")

            safetensors.torch.save_file(shard_tensors, shard_path)
            with open(skeleton_path, "wb") as skeleton_file:
                pickle.dump(shard_skeletons, skeleton_file, protocol=pickle.HIGHEST_PROTOCOL)

            logging.info(
                f"Saved shard {shard_idx + 1}/{num_shards}: {end_idx - start_idx} samples, "
                f"{len(shard_tensors)} tensors"
            )

        metadata = {
            "num_samples": num_samples,
            "num_shards": num_shards,
            "shard_size": shard_size,
            "format_version": 2,
        }
        metadata_path = os.path.join(output_dir, "metadata.json")
        with open(metadata_path, "w") as metadata_file:
            json.dump(metadata, metadata_file, indent=2)

        logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
        return io.NodeOutput()
|
|
|
|
|
|
def _expected_shard_files(dataset_dir):
    """Return the shard file names declared by the dataset's metadata.json.

    Returns:
        A set of ``shard_NNNN.safetensors`` names for shard indices
        ``0 .. num_shards - 1``, or None when metadata.json is missing,
        unreadable, or has no usable ``num_shards`` value.
    """
    metadata_path = os.path.join(dataset_dir, "metadata.json")
    try:
        with open(metadata_path, "r") as f:
            num_shards = json.load(f)["num_shards"]
    except (OSError, ValueError, KeyError, TypeError):
        return None
    # bool is an int subclass; reject it along with non-ints and negatives.
    if isinstance(num_shards, bool) or not isinstance(num_shards, int) or num_shards < 0:
        return None
    return {f"shard_{idx:04d}.safetensors" for idx in range(num_shards)}


class LoadTrainingDataset(io.ComfyNode):
    """Load an encoded training dataset from disk as lazy references.

    Outputs list[LazyLatent] and list[LazyConditioning], one handle per
    sample. This node reads only each shard's skeleton pickle; tensor bytes
    are read later, when a handle is realized. For consumers that are not
    lazy-aware, insert the Realize Lazy Latents / Realize Lazy Conditionings
    nodes to get standard in-RAM data.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadTrainingDataset",
            search_aliases=["import dataset", "training data", "lazy", "streaming"],
            display_name="Load Training Dataset",
            category="model/training",
            description="Load an encoded training dataset from disk as lazy references; tensors are read on demand during training instead of all at once.",
            is_experimental=True,
            inputs=[
                io.String.Input(
                    "folder_name",
                    default="training_dataset",
                    tooltip="Name of folder containing the saved dataset (inside output directory).",
                ),
            ],
            outputs=[
                io.Latent.Output(
                    display_name="latents",
                    is_output_list=True,
                    tooltip="List of latent dicts",
                ),
                io.Conditioning.Output(
                    display_name="conditioning",
                    is_output_list=True,
                    tooltip="List of conditioning lists",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder_name):
        """Index every shard of the dataset and return lazy per-sample handles.

        Shard files are taken from the directory listing. When metadata.json
        declares ``num_shards``, files outside that range are ignored with a
        warning: they are leftovers of an earlier, larger save into the same
        folder and would otherwise be mixed into the dataset.

        Args:
            folder_name: Dataset folder name, joined onto the output directory.

        Returns:
            io.NodeOutput with a list of LazyLatent and a parallel list of
            LazyConditioning.

        Raises:
            ValueError: If the directory does not exist or holds no usable
                shard files.
        """
        dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)

        if not os.path.exists(dataset_dir):
            raise ValueError(f"Dataset directory not found: {dataset_dir}")

        shard_files = sorted(
            f
            for f in os.listdir(dataset_dir)
            if f.startswith("shard_") and f.endswith(".safetensors")
        )

        # Without usable metadata, fall back to loading every shard found.
        expected = _expected_shard_files(dataset_dir)
        if expected is not None:
            stale = [f for f in shard_files if f not in expected]
            if stale:
                logging.warning(
                    f"Ignoring {len(stale)} shard file(s) in {dataset_dir} not listed by "
                    f"metadata.json (stale from an earlier save?): {stale}"
                )
                shard_files = [f for f in shard_files if f in expected]

        if not shard_files:
            raise ValueError(
                f"No shard files found in {dataset_dir} "
                f"(expected shard_*.safetensors + shard_*.skeleton.pkl)."
            )

        logging.info(f"Lazy-loading {len(shard_files)} shards from {dataset_dir}...")

        all_latents = []  # list[LazyLatent]
        all_conditioning = []  # list[LazyConditioning]

        for shard_file in shard_files:
            shard_path = os.path.join(dataset_dir, shard_file)
            skeleton_path = os.path.join(
                dataset_dir, shard_file[: -len(".safetensors")] + ".skeleton.pkl"
            )

            # Constructing the reader loads only the skeleton pickle.
            reader = _ShardReader(shard_path, skeleton_path)
            for local_idx in range(len(reader)):
                latent, cond = reader.read_sample_lazy(local_idx)
                all_latents.append(latent)
                all_conditioning.append(cond)

            logging.info(f"Indexed {shard_file}: {len(reader)} samples")

        logging.info(
            f"Lazy-loaded {len(all_latents)} samples from {dataset_dir} "
            f"(no tensor data read yet)."
        )
        return io.NodeOutput(all_latents, all_conditioning)
|
|
|
|
|
|
class RealizeLazyLatents(io.ComfyNode):
    """Turn lazy latents into standard latent dicts held in RAM.

    Insert before any node that is not lazy-aware (one that stacks or does
    tensor math on the latents). Real latents pass through unchanged, so the
    node is safe to apply unconditionally.
    """

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Latent.Input("latents", tooltip="Lazy (or real) latent dicts."),
        ]
        node_outputs = [
            io.Latent.Output(
                display_name="latents",
                is_output_list=True,
                tooltip="Realized (in-RAM) latent dicts",
            ),
        ]
        return io.Schema(
            node_id="RealizeLazyLatents",
            search_aliases=["realize", "materialize", "load to ram", "realize latents"],
            display_name="Realize Lazy Latents",
            category="model/training",
            description="Read all lazy latent tensors from disk into memory, producing standard in-RAM latent dicts.",
            is_experimental=True,
            is_input_list=True,
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, latents):
        """Realize every item of ``latents`` and return them as a list."""
        real_latents = []
        for item in latents:
            real_latents.append(_realize_structure(item))
        logging.info(f"Realized {len(real_latents)} latents into RAM.")
        return io.NodeOutput(real_latents)
|
|
|
|
|
|
class RealizeLazyConditionings(io.ComfyNode):
    """Turn lazy conditioning into standard conditioning held in RAM.

    Every lazy conditioning tensor is read from disk. Place this node ahead
    of any node that is not lazy-aware. Conditioning that is already real
    passes through unchanged, so applying the node unconditionally is safe.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: one conditioning list in, one realized list out."""
        conditioning_input = io.Conditioning.Input(
            "conditioning", tooltip="Lazy (or real) conditioning."
        )
        conditioning_output = io.Conditioning.Output(
            display_name="conditioning",
            is_output_list=True,
            tooltip="Realized (in-RAM) conditioning",
        )
        return io.Schema(
            node_id="RealizeLazyConditionings",
            search_aliases=["realize", "materialize", "load to ram", "realize conditioning"],
            display_name="Realize Lazy Conditionings",
            category="model/training",
            description="Read all lazy conditioning tensors from disk into memory, producing standard in-RAM conditioning.",
            is_experimental=True,
            is_input_list=True,
            inputs=[conditioning_input],
            outputs=[conditioning_output],
        )

    @classmethod
    def execute(cls, conditioning):
        """Realize each entry of ``conditioning`` and return them as one output list."""
        real_conditioning = list(map(_realize_structure, conditioning))
        logging.info(f"Realized {len(real_conditioning)} conditionings into RAM.")
        return io.NodeOutput(real_conditioning)
|
|
|
|
|
|
# ========== Extension Setup ==========
|
|
|
|
|
|
class DatasetExtension(ComfyExtension):
    """Extension object that exposes the dataset node classes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class of this extension, grouped by purpose.

        The groups are concatenated in the order they are declared below.
        """
        # Data loading/saving nodes
        data_io_nodes = [
            LoadImageDataSetFromFolderNode,
            LoadImageTextDataSetFromFolderNode,
            SaveImageDataSetToFolderNode,
            SaveImageTextDataSetToFolderNode,
        ]
        # Image transform nodes
        image_transform_nodes = [
            ResizeImagesByShorterEdgeNode,
            ResizeImagesByLongerEdgeNode,
            CenterCropImagesNode,
            RandomCropImagesNode,
            NormalizeImagesNode,
            AdjustBrightnessNode,
            AdjustContrastNode,
            ShuffleDatasetNode,
            ShuffleImageTextDatasetNode,
        ]
        # Text transform nodes
        text_transform_nodes = [
            TextToLowercaseNode,
            TextToUppercaseNode,
            TruncateTextNode,
            AddTextPrefixNode,
            AddTextSuffixNode,
            ReplaceTextNode,
            StripWhitespaceNode,
        ]
        # Group processing examples
        group_processing_nodes = [
            ImageDeduplicationNode,
            ImageGridNode,
            MergeImageListsNode,
            MergeTextListsNode,
        ]
        # Training dataset nodes
        training_dataset_nodes = [
            MakeTrainingDataset,
            SaveTrainingDataset,
            LoadTrainingDataset,
            RealizeLazyLatents,
            RealizeLazyConditionings,
            ResolutionBucket,
        ]
        return [
            *data_io_nodes,
            *image_transform_nodes,
            *text_transform_nodes,
            *group_processing_nodes,
            *training_dataset_nodes,
        ]
|
|
|
|
|
|
async def comfy_entrypoint() -> DatasetExtension:
    """Build and return a new DatasetExtension instance."""
    extension = DatasetExtension()
    return extension
|