diff --git a/nodes.py b/nodes.py index d695190bb..cd2b870b7 100644 --- a/nodes.py +++ b/nodes.py @@ -1712,25 +1712,32 @@ class LoadImage: } CATEGORY = "image" + SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" - SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) + img = node_helpers.pillow(Image.open, image_path) + output_images = [] output_masks = [] w, h = None, None + for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) + if i.mode == 'I': i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") + if len(output_images) == 0: w = image.size[0] h = image.size[1] + if image.size[0] != w or image.size[1] != h: continue + image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] if 'A' in i.getbands(): @@ -1743,15 +1750,19 @@ class LoadImage: mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") output_images.append(image) output_masks.append(mask.unsqueeze(0)) + if img.format == "MPO": break # ignore all frames except the first one for MPO format + if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: output_image = output_images[0] output_mask = output_masks[0] + return (output_image, output_mask) + @classmethod def IS_CHANGED(s, image): image_path = folder_paths.get_annotated_filepath(image) @@ -1759,10 +1770,13 @@ class LoadImage: with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod def VALIDATE_INPUTS(s, image): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + return True - class LoadImageMask: SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]