make training node to work with our dataset system

This commit is contained in:
Kohaku-Blueleaf 2025-11-03 00:33:54 +08:00
parent 7119b278d8
commit 32b44f5d1c
2 changed files with 61 additions and 14 deletions

View File

@ -1,6 +1,7 @@
import logging
import os
import pickle
import math
import numpy as np
import torch
@ -332,6 +333,30 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
return output_images
class ResizeImagesToPixelCountNode(ImageProcessingNode):
DESCRIPTION = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
@classmethod
def INPUT_TYPES(cls):
base_inputs = super().INPUT_TYPES()
base_inputs["required"]["pixel_count"] = ("INT", {"default": 512 * 512, "min": 1, "max": 8192 * 8192, "step": 1, "tooltip": "Target pixel count."})
base_inputs["required"]["steps"] = ("INT", {"default": 64, "min": 1, "max": 128, "step": 1, "tooltip": "The stepping for resize width/height."})
return base_inputs
def _process(self, images, pixel_count, steps):
output_images = []
for img_tensor in images:
img = self._tensor_to_pil(img_tensor)
w, h = img.size
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
new_w = int(h * pixel_count_ratio / steps) * steps
new_h = int(w * pixel_count_ratio / steps) * steps
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
output_images.append(self._pil_to_tensor(img))
return output_images
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
DESCRIPTION = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
@ -421,7 +446,7 @@ class RandomCropImagesNode(ImageProcessingNode):
return base_inputs
def _process(self, images, width, height, seed):
np.random.seed(seed)
np.random.seed(seed%(2**32-1))
output_images = []
for img_tensor in images:
img = self._tensor_to_pil(img_tensor)
@ -509,7 +534,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
return base_inputs
def _process(self, images, seed):
np.random.seed(seed)
np.random.seed(seed%(2**32-1))
indices = np.random.permutation(len(images))
return [images[i] for i in indices]
@ -534,7 +559,7 @@ class ShuffleImageTextDatasetNode:
}
def process(self, images, texts, seed):
np.random.seed(seed)
np.random.seed(seed%(2**32-1))
indices = np.random.permutation(len(images))
shuffled_images = [images[i] for i in indices]
shuffled_texts = [texts[i] for i in indices]
@ -637,7 +662,7 @@ class MakeTrainingDataset:
},
}
RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
RETURN_TYPES = ("LATENT", "CONDITIONING")
RETURN_NAMES = ("latents", "conditioning")
FUNCTION = "make_dataset"
CATEGORY = "dataset"
@ -667,7 +692,8 @@ class MakeTrainingDataset:
for img_tensor in images:
# img_tensor is [1, H, W, 3]
t = vae.encode(img_tensor[:,:,:,:3])
latents.append({"samples": t})
latents.append(t)
latents = {"samples": latents}
# Encode texts with CLIP
logging.info(f"Encoding {len(texts)} texts with CLIP...")
@ -681,7 +707,7 @@ class MakeTrainingDataset:
cond = clip.encode_from_tokens_scheduled(tokens)
conditions.extend(cond)
logging.info(f"Created dataset with {len(latents)} latents and {len(conditions)} conditions.")
logging.info(f"Created dataset with {len(latents['samples'])} latents and {len(conditions)} conditions.")
return (latents, conditions)
@ -692,7 +718,7 @@ class SaveTrainingDataset:
def INPUT_TYPES(s):
return {
"required": {
"latents": ("LATENT_LIST", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
"latents": ("LATENT", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
"conditioning": ("CONDITIONING", {"tooltip": "Conditioning list from MakeTrainingDataset."}),
"folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder to save dataset (inside output directory)."}),
"shard_size": ("INT", {"default": 1000, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of samples per shard file."}),
@ -708,7 +734,7 @@ class SaveTrainingDataset:
def save_dataset(self, latents, conditioning, folder_name, shard_size):
# Validate lengths match
if len(latents) != len(conditioning):
if len(latents["samples"]) != len(conditioning):
raise ValueError(
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
f"Something went wrong in dataset preparation."
@ -719,7 +745,7 @@ class SaveTrainingDataset:
os.makedirs(output_dir, exist_ok=True)
# Prepare data pairs
num_samples = len(latents)
num_samples = len(latents["samples"])
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
logging.info(f"Saving {num_samples} samples to {num_shards} shards in {output_dir}...")
@ -731,7 +757,7 @@ class SaveTrainingDataset:
# Get shard data
shard_data = {
"latents": latents[start_idx:end_idx],
"latents": latents["samples"][start_idx:end_idx],
"conditioning": conditioning[start_idx:end_idx],
}
@ -770,7 +796,7 @@ class LoadTrainingDataset:
},
}
RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
RETURN_TYPES = ("LATENT", "CONDITIONING")
RETURN_NAMES = ("latents", "conditioning")
FUNCTION = "load_dataset"
CATEGORY = "dataset"
@ -811,7 +837,7 @@ class LoadTrainingDataset:
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
logging.info(f"Successfully loaded {len(all_latents)} samples from {dataset_dir}.")
return (all_latents, all_conditioning)
return ({"samples": all_latents}, all_conditioning)
NODE_CLASS_MAPPINGS = {
@ -821,6 +847,7 @@ NODE_CLASS_MAPPINGS = {
"SaveImageTextDataSetToFolderNode": SaveImageTextDataSetToFolderNode,
# Image transforms
"ResizeImagesToSameSizeNode": ResizeImagesToSameSizeNode,
"ResizeImagesToPixelCountNode": ResizeImagesToPixelCountNode,
"ResizeImagesByShorterEdgeNode": ResizeImagesByShorterEdgeNode,
"ResizeImagesByLongerEdgeNode": ResizeImagesByLongerEdgeNode,
"CenterCropImagesNode": CenterCropImagesNode,
@ -852,6 +879,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImageTextDataSetToFolderNode": "Save Simple Image and Text Dataset to Folder",
# Image transforms
"ResizeImagesToSameSizeNode": "Resize Images to Same Size",
"ResizeImagesToPixelCountNode": "Resize Images to Pixel Count",
"ResizeImagesByShorterEdgeNode": "Resize Images by Shorter Edge",
"ResizeImagesByLongerEdgeNode": "Resize Images by Longer Edge",
"CenterCropImagesNode": "Center Crop Images",

View File

@ -558,8 +558,27 @@ class TrainLoraNode:
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
# latents here can be list of different size latent or one large batch
latents = latents["samples"]
if isinstance(latents, list):
all_shapes = set()
latents = [t.to(dtype) for t in latents]
for latent in latents:
all_shapes.add(latent.shape)
logging.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
raise ValueError(
"Different shapes latents are not currently supported"
)
else:
latents = torch.cat(latents, dim=0)
num_images = len(latents)
elif isinstance(latents, list):
latents = latents["samples"].to(dtype)
num_images = latents.shape[0]
else:
logging.error(f"Invalid latents type: {type(latents)}")
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
positive = positive * num_images