move all dataset-related implementation to nodes_dataset

Kohaku-Blueleaf 2025-11-04 12:51:40 +08:00
parent 992aa2dd8f
commit 650b7b0302


@@ -155,233 +155,6 @@ class BiasDiff(torch.nn.Module):
        return self.passive_memory_usage()


def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
    """Utility function to load and process a list of images.

    Args:
        image_files: List of image filenames
        input_dir: Base directory containing the images
        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
        w: Target width; defaults to the width of the first image
        h: Target height; defaults to the height of the first image

    Returns:
        torch.Tensor: Batch of processed images
    """
    if not image_files:
        raise ValueError("No valid images found in input")

    output_images = []
    for file in image_files:
        image_path = os.path.join(input_dir, file)
        img = node_helpers.pillow(Image.open, image_path)

        if img.mode == "I":
            img = img.point(lambda i: i * (1 / 255))
        img = img.convert("RGB")

        if w is None and h is None:
            w, h = img.size[0], img.size[1]

        # Resize image to match the first image
        if img.size[0] != w or img.size[1] != h:
            if resize_method == "Stretch":
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "Crop":
                img = img.crop((0, 0, w, h))
            elif resize_method == "Pad":
                # NOTE: "Pad" currently behaves like "Stretch"; no actual padding is applied.
                img = img.resize((w, h), Image.Resampling.LANCZOS)
            elif resize_method == "None":
                raise ValueError(
                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
                )

        img_array = np.array(img).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_array)[None,]
        output_images.append(img_tensor)

    return torch.cat(output_images, dim=0)
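For orientation, a minimal sketch of calling this helper directly; the filenames below are hypothetical:

    # Hypothetical call: batch two images from the ComfyUI input directory.
    batch = load_and_process_images(
        ["cat.png", "dog.png"],               # made-up filenames
        folder_paths.get_input_directory(),
        resize_method="Stretch",
    )
    print(batch.shape)  # (2, H, W, 3): NHWC float32 in [0, 1]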
class LoadImageSetNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": (
                    [
                        f
                        for f in os.listdir(folder_paths.get_input_directory())
                        if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
                    ],
                    {"image_upload": True, "allow_batch": True},
                )
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
            },
        }

    INPUT_IS_LIST = True
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images from a directory for training."

    @classmethod
    def VALIDATE_INPUTS(s, images, resize_method):
        filenames = images[0] if isinstance(images[0], list) else images
        for image in filenames:
            if not folder_paths.exists_annotated_filepath(image):
                return "Invalid image file: {}".format(image)
        return True

    def load_images(self, images, resize_method):
        # The parameter must match the "images" input declared in INPUT_TYPES;
        # with allow_batch the widget may deliver a nested list, as in VALIDATE_INPUTS.
        input_files = images[0] if isinstance(images[0], list) else images
        input_dir = folder_paths.get_input_directory()
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
        image_files = [
            f
            for f in input_files
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        output_tensor = load_and_process_images(image_files, input_dir, resize_method)
        return (output_tensor,)
class LoadImageSetFromFolderNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images from a directory for training."

    def load_images(self, folder, resize_method):
        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
        image_files = [
            f
            for f in os.listdir(sub_input_dir)
            if any(f.lower().endswith(ext) for ext in valid_extensions)
        ]
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
        return (output_tensor,)
class LoadImageTextSetFromFolderNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
            },
            "optional": {
                "resize_method": (
                    ["None", "Stretch", "Crop", "Pad"],
                    {"default": "None"},
                ),
                "width": (
                    IO.INT,
                    {
                        "default": -1,
                        "min": -1,
                        "max": 10000,
                        "step": 1,
                        "tooltip": "The width to resize the images to. -1 means use the original width.",
                    },
                ),
                "height": (
                    IO.INT,
                    {
                        "default": -1,
                        "min": -1,
                        "max": 10000,
                        "step": 1,
                        "tooltip": "The height to resize the images to. -1 means use the original height.",
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE", IO.CONDITIONING)
    FUNCTION = "load_images"
    CATEGORY = "loaders"
    EXPERIMENTAL = True
    DESCRIPTION = "Loads a batch of images and captions from a directory for training."

    def load_images(self, folder, clip, resize_method, width=None, height=None):
        if clip is None:
            raise RuntimeError(
                "ERROR: clip input is invalid: None\n\n"
                "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
            )

        logging.info(f"Loading images from folder: {folder}")

        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]

        image_files = []
        for item in os.listdir(sub_input_dir):
            path = os.path.join(sub_input_dir, item)
            if any(item.lower().endswith(ext) for ext in valid_extensions):
                image_files.append(path)
            elif os.path.isdir(path):
                # Support the kohya-ss/sd-scripts folder structure:
                # a "<repeats>_<name>" subfolder duplicates its images <repeats> times.
                repeat = 1
                if item.split("_")[0].isdigit():
                    repeat = int(item.split("_")[0])
                image_files.extend([
                    os.path.join(path, f)
                    for f in os.listdir(path)
                    if any(f.lower().endswith(ext) for ext in valid_extensions)
                ] * repeat)

        # Captions live next to the images as same-named .txt files.
        caption_file_path = [
            f.replace(os.path.splitext(f)[1], ".txt")
            for f in image_files
        ]
        captions = []
        for caption_file in caption_file_path:
            caption_path = os.path.join(sub_input_dir, caption_file)
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
                    captions.append(caption)
            else:
                captions.append("")

        width = width if width != -1 else None
        height = height if height != -1 else None
        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")

        logging.info(f"Encoding captions from {sub_input_dir}.")
        conditions = []
        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
        for text in captions:
            if text == "":
                conditions.append(empty_cond)
                continue  # was missing: without it, empty captions were encoded twice
            tokens = clip.tokenize(text)
            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")

        return (output_tensor, conditions)
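For reference, a sketch of the directory layout this node accepts; folder and file names are illustrative:

    input/
    └── my_dataset/           # selected via the "folder" input
        ├── photo1.png
        ├── photo1.txt        # caption for photo1.png; a missing .txt yields an empty caption
        └── 5_character/      # kohya-ss style: the leading digits repeat these images 5x
            ├── img001.png
            └── img001.txt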
def draw_loss_graph(loss_map, steps):
    width, height = 500, 300
    img = Image.new("RGB", (width, height), "white")
@@ -923,8 +696,6 @@ NODE_CLASS_MAPPINGS = {
    "TrainLoraNode": TrainLoraNode,
    "SaveLoRANode": SaveLoRA,
    "LoraModelLoader": LoraModelLoader,
    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
    "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
    "LossGraphNode": LossGraphNode,
}
@@ -932,7 +703,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "TrainLoraNode": "Train LoRA",
    "SaveLoRANode": "Save LoRA Weights",
    "LoraModelLoader": "Load LoRA Model",
    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
    "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
    "LossGraphNode": "Plot Loss Graph",
}