diff --git a/nodes.py b/nodes.py index ad474d3cd..1ca1f2c48 100644 --- a/nodes.py +++ b/nodes.py @@ -1695,44 +1695,41 @@ class PreviewImage(SaveImage): "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - files = folder_paths.filter_files_content_types(files, ["image"]) + image_paths = [] + for root, _, files in os.walk(input_dir, followlinks=True): + image_files = folder_paths.filter_files_content_types(files, ["image"]) + for image_file in image_files: + path_relative = os.path.relpath(os.path.join(root, image_file), input_dir) + path_relative = path_relative.replace('\\', '/') + image_paths.append(path_relative) return {"required": - {"image": (sorted(files), {"image_upload": True})}, + {"image": (sorted(list(set(image_paths))), {"image_upload": True})}, } CATEGORY = "image" - SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] - RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) - img = node_helpers.pillow(Image.open, image_path) - output_images = [] output_masks = [] w, h = None, None - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") - if len(output_images) == 0: w = image.size[0] h = image.size[1] - if image.size[0] != w or image.size[1] != h: continue - image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] if 'A' in i.getbands(): @@ -1745,19 +1742,15 @@ class LoadImage: mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if img.format == "MPO": break # ignore all frames except the first one for MPO format - if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: output_image = output_images[0] output_mask = output_masks[0] - return (output_image, output_mask) - @classmethod def IS_CHANGED(s, image): image_path = folder_paths.get_annotated_filepath(image) @@ -1765,13 +1758,10 @@ class LoadImage: with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() - @classmethod def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True + class LoadImageMask: SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]