Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-25 13:50:15 +08:00)
move all dataset related implementation to nodes_dataset
This commit is contained in:
parent 992aa2dd8f
commit 650b7b0302
@@ -155,233 +155,6 @@ class BiasDiff(torch.nn.Module):
         return self.passive_memory_usage()
 
 
-def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
-    """Utility function to load and process a list of images.
-
-    Args:
-        image_files: List of image filenames
-        input_dir: Base directory containing the images
-        resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
-
-    Returns:
-        torch.Tensor: Batch of processed images
-    """
-    if not image_files:
-        raise ValueError("No valid images found in input")
-
-    output_images = []
-
-    for file in image_files:
-        image_path = os.path.join(input_dir, file)
-        img = node_helpers.pillow(Image.open, image_path)
-
-        if img.mode == "I":
-            img = img.point(lambda i: i * (1 / 255))
-        img = img.convert("RGB")
-
-        if w is None and h is None:
-            w, h = img.size[0], img.size[1]
-
-        # Resize image to first image
-        if img.size[0] != w or img.size[1] != h:
-            if resize_method == "Stretch":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
-            elif resize_method == "Crop":
-                img = img.crop((0, 0, w, h))
-            elif resize_method == "Pad":
-                img = img.resize((w, h), Image.Resampling.LANCZOS)
-            elif resize_method == "None":
-                raise ValueError(
-                    "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
-                )
-
-        img_array = np.array(img).astype(np.float32) / 255.0
-        img_tensor = torch.from_numpy(img_array)[None,]
-        output_images.append(img_tensor)
-
-    return torch.cat(output_images, dim=0)
-
-
-class LoadImageSetNode:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "images": (
-                    [
-                        f
-                        for f in os.listdir(folder_paths.get_input_directory())
-                        if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"))
-                    ],
-                    {"image_upload": True, "allow_batch": True},
-                )
-            },
-            "optional": {
-                "resize_method": (
-                    ["None", "Stretch", "Crop", "Pad"],
-                    {"default": "None"},
-                ),
-            },
-        }
-
-    INPUT_IS_LIST = True
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "load_images"
-    CATEGORY = "loaders"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images from a directory for training."
-
-    @classmethod
-    def VALIDATE_INPUTS(s, images, resize_method):
-        filenames = images[0] if isinstance(images[0], list) else images
-
-        for image in filenames:
-            if not folder_paths.exists_annotated_filepath(image):
-                return "Invalid image file: {}".format(image)
-        return True
-
-    def load_images(self, input_files, resize_method):
-        input_dir = folder_paths.get_input_directory()
-        valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"]
-        image_files = [
-            f
-            for f in input_files
-            if any(f.lower().endswith(ext) for ext in valid_extensions)
-        ]
-        output_tensor = load_and_process_images(image_files, input_dir, resize_method)
-        return (output_tensor,)
-
-
-class LoadImageSetFromFolderNode:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."})
-            },
-            "optional": {
-                "resize_method": (
-                    ["None", "Stretch", "Crop", "Pad"],
-                    {"default": "None"},
-                ),
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "load_images"
-    CATEGORY = "loaders"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images from a directory for training."
-
-    def load_images(self, folder, resize_method):
-        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
-        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
-        image_files = [
-            f
-            for f in os.listdir(sub_input_dir)
-            if any(f.lower().endswith(ext) for ext in valid_extensions)
-        ]
-        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method)
-        return (output_tensor,)
-
-
-class LoadImageTextSetFromFolderNode:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
-                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
-            },
-            "optional": {
-                "resize_method": (
-                    ["None", "Stretch", "Crop", "Pad"],
-                    {"default": "None"},
-                ),
-                "width": (
-                    IO.INT,
-                    {
-                        "default": -1,
-                        "min": -1,
-                        "max": 10000,
-                        "step": 1,
-                        "tooltip": "The width to resize the images to. -1 means use the original width.",
-                    },
-                ),
-                "height": (
-                    IO.INT,
-                    {
-                        "default": -1,
-                        "min": -1,
-                        "max": 10000,
-                        "step": 1,
-                        "tooltip": "The height to resize the images to. -1 means use the original height.",
-                    },
-                )
-            },
-        }
-
-    RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
-    FUNCTION = "load_images"
-    CATEGORY = "loaders"
-    EXPERIMENTAL = True
-    DESCRIPTION = "Loads a batch of images and caption from a directory for training."
-
-    def load_images(self, folder, clip, resize_method, width=None, height=None):
-        if clip is None:
-            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
-
-        logging.info(f"Loading images from folder: {folder}")
-
-        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
-        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
-
-        image_files = []
-        for item in os.listdir(sub_input_dir):
-            path = os.path.join(sub_input_dir, item)
-            if any(item.lower().endswith(ext) for ext in valid_extensions):
-                image_files.append(path)
-            elif os.path.isdir(path):
-                # Support kohya-ss/sd-scripts folder structure
-                repeat = 1
-                if item.split("_")[0].isdigit():
-                    repeat = int(item.split("_")[0])
-                image_files.extend([
-                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
-                ] * repeat)
-
-        caption_file_path = [
-            f.replace(os.path.splitext(f)[1], ".txt")
-            for f in image_files
-        ]
-        captions = []
-        for caption_file in caption_file_path:
-            caption_path = os.path.join(sub_input_dir, caption_file)
-            if os.path.exists(caption_path):
-                with open(caption_path, "r", encoding="utf-8") as f:
-                    caption = f.read().strip()
-                    captions.append(caption)
-            else:
-                captions.append("")
-
-        width = width if width != -1 else None
-        height = height if height != -1 else None
-        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
-
-        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
-
-        logging.info(f"Encoding captions from {sub_input_dir}.")
-        conditions = []
-        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
-        for text in captions:
-            if text == "":
-                conditions.append(empty_cond)
-            tokens = clip.tokenize(text)
-            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
-        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
-        return (output_tensor, conditions)
-
-
 def draw_loss_graph(loss_map, steps):
     width, height = 500, 300
     img = Image.new("RGB", (width, height), "white")
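For context on the hunk above: load_and_process_images forces every image in the batch to match the dimensions of the first image loaded. With resize_method "None" a size mismatch raises a ValueError, "Stretch" and "Pad" both LANCZOS-resize to the first image's size (the "Pad" branch does not actually pad), and "Crop" takes the top-left w by h region. Below is a minimal standalone sketch of that sizing behavior using only Pillow, NumPy, and PyTorch; the function name load_batch and the input/example_dataset path are illustrative assumptions, not part of the ComfyUI API.

# Minimal sketch of the removed helper's sizing rules (not ComfyUI API).
import os
import numpy as np
import torch
from PIL import Image

def load_batch(input_dir, resize_method="None"):
    files = sorted(f for f in os.listdir(input_dir)
                   if f.lower().endswith((".png", ".jpg", ".jpeg", ".webp")))
    tensors, size = [], None
    for name in files:
        img = Image.open(os.path.join(input_dir, name)).convert("RGB")
        if size is None:
            size = img.size  # first image fixes (w, h) for the whole batch
        if img.size != size:
            if resize_method in ("Stretch", "Pad"):  # both resize in the original
                img = img.resize(size, Image.Resampling.LANCZOS)
            elif resize_method == "Crop":
                img = img.crop((0, 0, size[0], size[1]))  # top-left crop
            else:  # "None"
                raise ValueError("image size mismatch; choose a resize method")
        arr = np.asarray(img, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
        tensors.append(torch.from_numpy(arr)[None,])     # prepend batch dim
    if not tensors:
        raise ValueError("No valid images found in input")
    return torch.cat(tensors, dim=0)  # (N, H, W, C), ComfyUI's IMAGE layout

# Usage (hypothetical folder):
# batch = load_batch("input/example_dataset", resize_method="Stretch")
# print(batch.shape)  # torch.Size([N, H, W, 3])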
@@ -923,8 +696,6 @@ NODE_CLASS_MAPPINGS = {
     "TrainLoraNode": TrainLoraNode,
     "SaveLoRANode": SaveLoRA,
     "LoraModelLoader": LoraModelLoader,
-    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
-    "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
     "LossGraphNode": LossGraphNode,
 }
 
@@ -932,7 +703,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "TrainLoraNode": "Train LoRA",
     "SaveLoRANode": "Save LoRA Weights",
     "LoraModelLoader": "Load LoRA Model",
-    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
-    "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
     "LossGraphNode": "Plot Loss Graph",
 }
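This commit shows only the removal side of the move. ComfyUI discovers nodes through module-level NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS dicts, so the relocated loaders have to be re-registered under the same keys for existing workflows to keep resolving. A sketch of what the receiving nodes_dataset module would export, assuming the classes are moved verbatim (the exact file layout is an assumption, since the addition side is not shown here):

# nodes_dataset.py -- assumed destination module (sketch only).
# The real class bodies are the ones deleted in the first hunk above.
class LoadImageSetFromFolderNode: ...
class LoadImageTextSetFromFolderNode: ...

# ComfyUI scans these module-level dicts when registering nodes; the keys
# must stay identical to the removed entries so saved workflows still load.
NODE_CLASS_MAPPINGS = {
    "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
    "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
    "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
}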