make training node work with our dataset system
parent 7119b278d8
commit 32b44f5d1c
@@ -1,6 +1,7 @@
 import logging
 import os
 import pickle
+import math

 import numpy as np
 import torch
@@ -332,6 +333,30 @@ class ResizeImagesToSameSizeNode(ImageProcessingNode):
         return output_images


+class ResizeImagesToPixelCountNode(ImageProcessingNode):
+    DESCRIPTION = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        base_inputs = super().INPUT_TYPES()
+        base_inputs["required"]["pixel_count"] = ("INT", {"default": 512 * 512, "min": 1, "max": 8192 * 8192, "step": 1, "tooltip": "Target pixel count."})
+        base_inputs["required"]["steps"] = ("INT", {"default": 64, "min": 1, "max": 128, "step": 1, "tooltip": "The stepping for resize width/height."})
+        return base_inputs
+
+    def _process(self, images, pixel_count, steps):
+        output_images = []
+        for img_tensor in images:
+            img = self._tensor_to_pil(img_tensor)
+            w, h = img.size
+            pixel_count_ratio = math.sqrt(pixel_count / (w * h))
+            new_w = int(w * pixel_count_ratio / steps) * steps
+            new_h = int(h * pixel_count_ratio / steps) * steps
+            logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
+            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+            output_images.append(self._pil_to_tensor(img))
+        return output_images
+
+
 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     DESCRIPTION = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
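Note: the new node scales both edges by sqrt(pixel_count / (w * h)) and then snaps each scaled edge down to a multiple of steps, so the result can land somewhat under the target count. A minimal standalone sketch of that arithmetic (pure standard-library Python; the helper name and sample numbers are illustrative):

    import math

    def target_size(w, h, pixel_count=512 * 512, steps=64):
        # Uniform scale factor that maps w * h onto the requested pixel count.
        ratio = math.sqrt(pixel_count / (w * h))
        # Snap each scaled edge down to the nearest multiple of steps.
        return int(w * ratio / steps) * steps, int(h * ratio / steps) * steps

    print(target_size(1024, 768))  # (576, 384) -> 221,184 pixels, just under the 262,144 target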
@@ -421,7 +446,7 @@ class RandomCropImagesNode(ImageProcessingNode):
         return base_inputs

     def _process(self, images, width, height, seed):
-        np.random.seed(seed)
+        np.random.seed(seed % (2**32 - 1))
         output_images = []
         for img_tensor in images:
             img = self._tensor_to_pil(img_tensor)
@@ -509,7 +534,7 @@ class ShuffleDatasetNode(ImageProcessingNode):
         return base_inputs

     def _process(self, images, seed):
-        np.random.seed(seed)
+        np.random.seed(seed % (2**32 - 1))
         indices = np.random.permutation(len(images))
         return [images[i] for i in indices]

@@ -534,7 +559,7 @@ class ShuffleImageTextDatasetNode:
         }

     def process(self, images, texts, seed):
-        np.random.seed(seed)
+        np.random.seed(seed % (2**32 - 1))
         indices = np.random.permutation(len(images))
         shuffled_images = [images[i] for i in indices]
         shuffled_texts = [texts[i] for i in indices]
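Note: the same seed clamp appears in all three nodes above. NumPy's legacy np.random.seed only accepts integers in [0, 2**32 - 1], while ComfyUI seeds can be much larger, so seeding directly would raise ValueError. A quick illustration (the sample seed is arbitrary):

    import numpy as np

    seed = 2**40 + 123  # larger than NumPy's 32-bit seed limit
    # np.random.seed(seed) would raise "ValueError: Seed must be between 0 and 2**32 - 1"
    np.random.seed(seed % (2**32 - 1))  # clamped into the accepted range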
@@ -637,7 +662,7 @@ class MakeTrainingDataset:
             },
         }

-    RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
+    RETURN_TYPES = ("LATENT", "CONDITIONING")
     RETURN_NAMES = ("latents", "conditioning")
     FUNCTION = "make_dataset"
     CATEGORY = "dataset"
@@ -667,7 +692,8 @@ class MakeTrainingDataset:
         for img_tensor in images:
            # img_tensor is [1, H, W, 3]
            t = vae.encode(img_tensor[:,:,:,:3])
-            latents.append({"samples": t})
+            latents.append(t)
+        latents = {"samples": latents}

         # Encode texts with CLIP
         logging.info(f"Encoding {len(texts)} texts with CLIP...")
@@ -681,7 +707,7 @@ class MakeTrainingDataset:
             cond = clip.encode_from_tokens_scheduled(tokens)
             conditions.extend(cond)

-        logging.info(f"Created dataset with {len(latents)} latents and {len(conditions)} conditions.")
+        logging.info(f"Created dataset with {len(latents['samples'])} latents and {len(conditions)} conditions.")
         return (latents, conditions)

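Note: taken together, the MakeTrainingDataset hunks change the latent payload from one LATENT-like dict per image to a single LATENT dict holding a list of tensors, matching the node's new LATENT return type. Schematically (the stand-in objects are illustrative):

    t0 = t1 = object()  # stand-ins for per-image encoded latent tensors
    before = [{"samples": t0}, {"samples": t1}]  # old: one dict per image
    after = {"samples": [t0, t1]}                # new: one dict, list of tensors inside
    assert len(after["samples"]) == len(before)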
@@ -692,7 +718,7 @@ class SaveTrainingDataset:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "latents": ("LATENT_LIST", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
+                "latents": ("LATENT", {"tooltip": "List of latent tensors from MakeTrainingDataset."}),
                 "conditioning": ("CONDITIONING", {"tooltip": "Conditioning list from MakeTrainingDataset."}),
                 "folder_name": ("STRING", {"default": "training_dataset", "tooltip": "Name of folder to save dataset (inside output directory)."}),
                 "shard_size": ("INT", {"default": 1000, "min": 1, "max": 100000, "step": 1, "tooltip": "Number of samples per shard file."}),
@@ -708,7 +734,7 @@ class SaveTrainingDataset:

     def save_dataset(self, latents, conditioning, folder_name, shard_size):
         # Validate lengths match
-        if len(latents) != len(conditioning):
+        if len(latents["samples"]) != len(conditioning):
             raise ValueError(
                 f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
                 f"Something went wrong in dataset preparation."
@@ -719,7 +745,7 @@ class SaveTrainingDataset:
         os.makedirs(output_dir, exist_ok=True)

         # Prepare data pairs
-        num_samples = len(latents)
+        num_samples = len(latents["samples"])
         num_shards = (num_samples + shard_size - 1) // shard_size  # Ceiling division

         logging.info(f"Saving {num_samples} samples to {num_shards} shards in {output_dir}...")
@@ -731,7 +757,7 @@ class SaveTrainingDataset:

             # Get shard data
             shard_data = {
-                "latents": latents[start_idx:end_idx],
+                "latents": latents["samples"][start_idx:end_idx],
                 "conditioning": conditioning[start_idx:end_idx],
             }

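Note: sharding is plain start/end slicing with a ceiling division for the shard count. A minimal sketch of the bookkeeping (the helper name is illustrative):

    def shard_bounds(num_samples, shard_size):
        num_shards = (num_samples + shard_size - 1) // shard_size  # ceiling division
        for i in range(num_shards):
            start_idx = i * shard_size
            end_idx = min(start_idx + shard_size, num_samples)
            yield start_idx, end_idx

    print(list(shard_bounds(2500, 1000)))  # [(0, 1000), (1000, 2000), (2000, 2500)]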
@@ -770,7 +796,7 @@ class LoadTrainingDataset:
         },
     }

-    RETURN_TYPES = ("LATENT_LIST", "CONDITIONING")
+    RETURN_TYPES = ("LATENT", "CONDITIONING")
     RETURN_NAMES = ("latents", "conditioning")
     FUNCTION = "load_dataset"
     CATEGORY = "dataset"
@@ -811,7 +837,7 @@ class LoadTrainingDataset:
             logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")

         logging.info(f"Successfully loaded {len(all_latents)} samples from {dataset_dir}.")
-        return (all_latents, all_conditioning)
+        return ({"samples": all_latents}, all_conditioning)


 NODE_CLASS_MAPPINGS = {
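Note: the load path mirrors the save path: shard dicts with "latents" and "conditioning" keys are read back, concatenated, and the latents are rewrapped as {"samples": [...]}. A hedged sketch of that round trip, assuming pickled shard files (the file-matching logic and extension here are assumptions, not the node's actual code):

    import os
    import pickle

    def load_shards(dataset_dir):
        all_latents, all_conditioning = [], []
        for name in sorted(os.listdir(dataset_dir)):
            if not name.endswith(".pkl"):  # assumed shard extension
                continue
            with open(os.path.join(dataset_dir, name), "rb") as f:
                shard_data = pickle.load(f)
            all_latents.extend(shard_data["latents"])
            all_conditioning.extend(shard_data["conditioning"])
        # Same return shape as LoadTrainingDataset after this commit:
        return {"samples": all_latents}, all_conditioning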
@@ -821,6 +847,7 @@ NODE_CLASS_MAPPINGS = {
     "SaveImageTextDataSetToFolderNode": SaveImageTextDataSetToFolderNode,
     # Image transforms
     "ResizeImagesToSameSizeNode": ResizeImagesToSameSizeNode,
+    "ResizeImagesToPixelCountNode": ResizeImagesToPixelCountNode,
     "ResizeImagesByShorterEdgeNode": ResizeImagesByShorterEdgeNode,
     "ResizeImagesByLongerEdgeNode": ResizeImagesByLongerEdgeNode,
     "CenterCropImagesNode": CenterCropImagesNode,
@@ -852,6 +879,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SaveImageTextDataSetToFolderNode": "Save Simple Image and Text Dataset to Folder",
     # Image transforms
     "ResizeImagesToSameSizeNode": "Resize Images to Same Size",
+    "ResizeImagesToPixelCountNode": "Resize Images to Pixel Count",
     "ResizeImagesByShorterEdgeNode": "Resize Images by Shorter Edge",
     "ResizeImagesByLongerEdgeNode": "Resize Images by Longer Edge",
     "CenterCropImagesNode": "Center Crop Images",
@@ -558,8 +558,27 @@ class TrainLoraNode:
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)
-        latents = latents["samples"].to(dtype)
-        num_images = latents.shape[0]
+
+        # latents here can be a list of per-image latents of varying sizes, or one large batch tensor
+        latents = latents["samples"]
+        if isinstance(latents, list):
+            all_shapes = set()
+            latents = [t.to(dtype) for t in latents]
+            for latent in latents:
+                all_shapes.add(latent.shape)
+            logging.info(f"Latent shapes: {all_shapes}")
+            if len(all_shapes) > 1:
+                raise ValueError(
+                    "Latents of different shapes are not currently supported"
+                )
+            else:
+                latents = torch.cat(latents, dim=0)
+            num_images = len(latents)
+        elif isinstance(latents, torch.Tensor):
+            latents = latents.to(dtype)
+            num_images = latents.shape[0]
+        else:
+            logging.error(f"Invalid latents type: {type(latents)}")

         logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
         if len(positive) == 1 and num_images > 1:
             positive = positive * num_images
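Note: the TrainLoraNode change above accepts either payload shape from the dataset nodes. The same branch, extracted as a standalone helper for clarity (the function name is illustrative, and this sketch raises on an unknown type where the real node only logs an error):

    import torch

    def normalize_latents(latent_dict, dtype=torch.float32):
        samples = latent_dict["samples"]
        if isinstance(samples, list):
            # List of per-image latents: all shapes must match before batching.
            shapes = {t.shape for t in samples}
            if len(shapes) > 1:
                raise ValueError("Latents of different shapes are not currently supported")
            samples = torch.cat([t.to(dtype) for t in samples], dim=0)
        elif isinstance(samples, torch.Tensor):
            samples = samples.to(dtype)
        else:
            raise TypeError(f"Invalid latents type: {type(samples)}")
        return samples, samples.shape[0]

    batch, n = normalize_latents({"samples": [torch.zeros(1, 4, 64, 64)] * 3})
    print(batch.shape, n)  # torch.Size([3, 4, 64, 64]) 3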