mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 06:40:16 +08:00
Add encoded dataset caching mechanism
This commit is contained in:
parent
c5a8ec3f67
commit
d5aea27817
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -50,7 +51,7 @@ class LoadImageDataSetFromFolderNode:
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE_LIST",)
|
RETURN_TYPES = ("IMAGE_LIST",)
|
||||||
FUNCTION = "load_images"
|
FUNCTION = "load_images"
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "dataset"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
DESCRIPTION = "Loads a batch of images from a directory for training."
|
DESCRIPTION = "Loads a batch of images from a directory for training."
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class LoadImageTextDataSetFromFolderNode:
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST",)
|
RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST",)
|
||||||
FUNCTION = "load_images"
|
FUNCTION = "load_images"
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "dataset"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
DESCRIPTION = "Loads a batch of images and caption from a directory for training."
|
DESCRIPTION = "Loads a batch of images and caption from a directory for training."
|
||||||
|
|
||||||
@ -115,8 +116,6 @@ class LoadImageTextDataSetFromFolderNode:
|
|||||||
else:
|
else:
|
||||||
captions.append("")
|
captions.append("")
|
||||||
|
|
||||||
width = width if width != -1 else None
|
|
||||||
height = height if height != -1 else None
|
|
||||||
output_tensor = load_and_process_images(image_files, sub_input_dir)
|
output_tensor = load_and_process_images(image_files, sub_input_dir)
|
||||||
|
|
||||||
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
|
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
|
||||||
@ -181,7 +180,7 @@ class SaveImageDataSetToFolderNode:
|
|||||||
RETURN_TYPES = ()
|
RETURN_TYPES = ()
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
FUNCTION = "save_images"
|
FUNCTION = "save_images"
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "dataset"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
DESCRIPTION = "Saves a batch of images to a directory."
|
DESCRIPTION = "Saves a batch of images to a directory."
|
||||||
|
|
||||||
@ -208,7 +207,7 @@ class SaveImageTextDataSetToFolderNode:
|
|||||||
RETURN_TYPES = ()
|
RETURN_TYPES = ()
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
FUNCTION = "save_images"
|
FUNCTION = "save_images"
|
||||||
CATEGORY = "loaders"
|
CATEGORY = "dataset"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
DESCRIPTION = "Saves a batch of images and captions to a directory."
|
DESCRIPTION = "Saves a batch of images and captions to a directory."
|
||||||
|
|
||||||
@ -232,7 +231,7 @@ class SaveImageTextDataSetToFolderNode:
|
|||||||
class ImageProcessingNode:
|
class ImageProcessingNode:
|
||||||
"""Base class for image processing nodes that operate on IMAGE_LIST."""
|
"""Base class for image processing nodes that operate on IMAGE_LIST."""
|
||||||
|
|
||||||
CATEGORY = "image/transforms"
|
CATEGORY = "dataset/image"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("IMAGE_LIST",)
|
RETURN_TYPES = ("IMAGE_LIST",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
@ -269,7 +268,7 @@ class ImageProcessingNode:
|
|||||||
class TextProcessingNode:
|
class TextProcessingNode:
|
||||||
"""Base class for text processing nodes that operate on TEXT_LIST."""
|
"""Base class for text processing nodes that operate on TEXT_LIST."""
|
||||||
|
|
||||||
CATEGORY = "text/transforms"
|
CATEGORY = "dataset/text"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("TEXT_LIST",)
|
RETURN_TYPES = ("TEXT_LIST",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
@ -518,7 +517,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
|
|||||||
class ShuffleImageTextDatasetNode:
|
class ShuffleImageTextDatasetNode:
|
||||||
"""Special node that shuffles both images and texts together (doesn't inherit from base class)."""
|
"""Special node that shuffles both images and texts together (doesn't inherit from base class)."""
|
||||||
|
|
||||||
CATEGORY = "image/transforms"
|
CATEGORY = "dataset/image"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST")
|
RETURN_TYPES = ("IMAGE_LIST", "TEXT_LIST")
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
@ -620,6 +619,201 @@ class StripWhitespaceNode(TextProcessingNode):
|
|||||||
return [text.strip() for text in texts]
|
return [text.strip() for text in texts]
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Training Dataset Nodes ==========
|
||||||
|
|
||||||
|
class MakeTrainingDataset:
|
||||||
|
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"images": ("IMAGE_LIST", {"tooltip": "List of images to encode."}),
|
||||||
|
"vae": ("VAE", {"tooltip": "VAE model for encoding images to latents."}),
|
||||||
|
"clip": ("CLIP", {"tooltip": "CLIP model for encoding text to conditioning."}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"texts": ("TEXT_LIST", {"tooltip": "List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string)."}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
|
||||||
|
RETURN_NAMES = ("latents", "conditioning")
|
||||||
|
FUNCTION = "make_dataset"
|
||||||
|
CATEGORY = "dataset"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Encodes images with VAE and texts with CLIP to create a training dataset. Returns a list of latents and a flat conditioning list."
|
||||||
|
|
||||||
|
def make_dataset(self, images, vae, clip, texts=None):
|
||||||
|
# Handle text list
|
||||||
|
num_images = len(images)
|
||||||
|
|
||||||
|
if texts is None or len(texts) == 0:
|
||||||
|
# Treat as [""] for unconditional training
|
||||||
|
texts = [""]
|
||||||
|
|
||||||
|
if len(texts) == 1 and num_images > 1:
|
||||||
|
# Repeat single text for all images
|
||||||
|
texts = texts * num_images
|
||||||
|
elif len(texts) != num_images:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
|
||||||
|
f"Text list should have length {num_images}, 1, or 0."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encode images with VAE
|
||||||
|
logging.info(f"Encoding {num_images} images with VAE...")
|
||||||
|
latents = []
|
||||||
|
for img_tensor in images:
|
||||||
|
# img_tensor is [1, H, W, 3]
|
||||||
|
t = vae.encode(img_tensor[:,:,:,:3])
|
||||||
|
latents.append({"samples": t})
|
||||||
|
|
||||||
|
# Encode texts with CLIP
|
||||||
|
logging.info(f"Encoding {len(texts)} texts with CLIP...")
|
||||||
|
conditions = []
|
||||||
|
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
|
||||||
|
for text in texts:
|
||||||
|
if text == "":
|
||||||
|
conditions.extend(empty_cond)
|
||||||
|
else:
|
||||||
|
tokens = clip.tokenize(text)
|
||||||
|
cond = clip.encode_from_tokens_scheduled(tokens)
|
||||||
|
conditions.extend(cond)
|
||||||
|
|
||||||
|
logging.info(f"Created dataset with {len(latents)} latents and {len(conditions)} conditions.")
|
||||||
|
return (latents, conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class SaveTrainingDataset:
|
||||||
|
"""Save encoded training dataset (latents + conditioning) to disk."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"latents": ("LATENT_LIST", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
|
||||||
|
"conditioning": ("CONDITIONING", {"tooltip": "Conditioning list from MakeTrainingDataset."}),
|
||||||
|
"folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder to save dataset (inside output directory)."}),
|
||||||
|
"shard_size": ("INT", {"default": 1000, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of samples per shard file."}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
FUNCTION = "save_dataset"
|
||||||
|
CATEGORY = "dataset"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Saves a training dataset to disk in sharded pickle files. Each shard contains (latent, conditioning) pairs."
|
||||||
|
|
||||||
|
def save_dataset(self, latents, conditioning, folder_name, shard_size):
|
||||||
|
# Validate lengths match
|
||||||
|
if len(latents) != len(conditioning):
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
|
||||||
|
f"Something went wrong in dataset preparation."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Prepare data pairs
|
||||||
|
num_samples = len(latents)
|
||||||
|
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
|
||||||
|
|
||||||
|
logging.info(f"Saving {num_samples} samples to {num_shards} shards in {output_dir}...")
|
||||||
|
|
||||||
|
# Save data in shards
|
||||||
|
for shard_idx in range(num_shards):
|
||||||
|
start_idx = shard_idx * shard_size
|
||||||
|
end_idx = min(start_idx + shard_size, num_samples)
|
||||||
|
|
||||||
|
# Get shard data
|
||||||
|
shard_data = {
|
||||||
|
"latents": latents[start_idx:end_idx],
|
||||||
|
"conditioning": conditioning[start_idx:end_idx],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save shard
|
||||||
|
shard_filename = f"shard_{shard_idx:04d}.pkl"
|
||||||
|
shard_path = os.path.join(output_dir, shard_filename)
|
||||||
|
|
||||||
|
with open(shard_path, "wb") as f:
|
||||||
|
pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
logging.info(f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)")
|
||||||
|
|
||||||
|
# Save metadata
|
||||||
|
metadata = {
|
||||||
|
"num_samples": num_samples,
|
||||||
|
"num_shards": num_shards,
|
||||||
|
"shard_size": shard_size,
|
||||||
|
}
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.json")
|
||||||
|
with open(metadata_path, "w") as f:
|
||||||
|
import json
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
|
logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTrainingDataset:
|
||||||
|
"""Load encoded training dataset from disk."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder containing the saved dataset (inside output directory)."}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
|
||||||
|
RETURN_NAMES = ("latents", "conditioning")
|
||||||
|
FUNCTION = "load_dataset"
|
||||||
|
CATEGORY = "dataset"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
DESCRIPTION = "Loads a training dataset from disk. Returns a list of latents and a flat conditioning list."
|
||||||
|
|
||||||
|
def load_dataset(self, folder_name):
|
||||||
|
# Get dataset directory
|
||||||
|
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||||
|
|
||||||
|
if not os.path.exists(dataset_dir):
|
||||||
|
raise ValueError(f"Dataset directory not found: {dataset_dir}")
|
||||||
|
|
||||||
|
# Find all shard files
|
||||||
|
shard_files = sorted([
|
||||||
|
f for f in os.listdir(dataset_dir)
|
||||||
|
if f.startswith("shard_") and f.endswith(".pkl")
|
||||||
|
])
|
||||||
|
|
||||||
|
if not shard_files:
|
||||||
|
raise ValueError(f"No shard files found in {dataset_dir}")
|
||||||
|
|
||||||
|
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
|
||||||
|
|
||||||
|
# Load all shards
|
||||||
|
all_latents = []
|
||||||
|
all_conditioning = []
|
||||||
|
|
||||||
|
for shard_file in shard_files:
|
||||||
|
shard_path = os.path.join(dataset_dir, shard_file)
|
||||||
|
|
||||||
|
with open(shard_path, "rb") as f:
|
||||||
|
shard_data = pickle.load(f)
|
||||||
|
|
||||||
|
all_latents.extend(shard_data["latents"])
|
||||||
|
all_conditioning.extend(shard_data["conditioning"])
|
||||||
|
|
||||||
|
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
|
||||||
|
|
||||||
|
logging.info(f"Successfully loaded {len(all_latents)} samples from {dataset_dir}.")
|
||||||
|
return (all_latents, all_conditioning)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LoadImageDataSetFromFolderNode": LoadImageDataSetFromFolderNode,
|
"LoadImageDataSetFromFolderNode": LoadImageDataSetFromFolderNode,
|
||||||
"LoadImageTextDataSetFromFolderNode": LoadImageTextDataSetFromFolderNode,
|
"LoadImageTextDataSetFromFolderNode": LoadImageTextDataSetFromFolderNode,
|
||||||
@ -645,6 +839,10 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"AddTextSuffixNode": AddTextSuffixNode,
|
"AddTextSuffixNode": AddTextSuffixNode,
|
||||||
"ReplaceTextNode": ReplaceTextNode,
|
"ReplaceTextNode": ReplaceTextNode,
|
||||||
"StripWhitespaceNode": StripWhitespaceNode,
|
"StripWhitespaceNode": StripWhitespaceNode,
|
||||||
|
# Training dataset nodes
|
||||||
|
"MakeTrainingDataset": MakeTrainingDataset,
|
||||||
|
"SaveTrainingDataset": SaveTrainingDataset,
|
||||||
|
"LoadTrainingDataset": LoadTrainingDataset,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -672,4 +870,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"AddTextSuffixNode": "Add Text Suffix",
|
"AddTextSuffixNode": "Add Text Suffix",
|
||||||
"ReplaceTextNode": "Replace Text",
|
"ReplaceTextNode": "Replace Text",
|
||||||
"StripWhitespaceNode": "Strip Whitespace",
|
"StripWhitespaceNode": "Strip Whitespace",
|
||||||
|
# Training dataset nodes
|
||||||
|
"MakeTrainingDataset": "Make Training Dataset",
|
||||||
|
"SaveTrainingDataset": "Save Training Dataset",
|
||||||
|
"LoadTrainingDataset": "Load Training Dataset",
|
||||||
}
|
}
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@ -2275,6 +2275,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_images.py",
|
"nodes_images.py",
|
||||||
"nodes_video_model.py",
|
"nodes_video_model.py",
|
||||||
"nodes_train.py",
|
"nodes_train.py",
|
||||||
|
"nodes_dataset.py",
|
||||||
"nodes_sag.py",
|
"nodes_sag.py",
|
||||||
"nodes_perpneg.py",
|
"nodes_perpneg.py",
|
||||||
"nodes_stable3d.py",
|
"nodes_stable3d.py",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user