diff --git a/nodes.py b/nodes.py index 710cccffe..8f8f90cf6 100644 --- a/nodes.py +++ b/nodes.py @@ -1754,57 +1754,49 @@ class LoadImage: return True -class LoadImageMask: + +class LoadImageMask(LoadImage): ESSENTIALS_CATEGORY = "Image Tools" SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] _color_channels = ["alpha", "red", "green", "blue"] + @classmethod def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } - } + types = super().INPUT_TYPES() + return { + "required": { + **types["required"], + "channel": (s._color_channels, ) + } + } CATEGORY = "mask" - RETURN_TYPES = ("MASK",) - FUNCTION = "load_image" - def load_image(self, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - i = node_helpers.pillow(Image.open, image_path) - i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.getbands() != ("R", "G", "B", "A"): - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - i = i.convert("RGBA") - mask = None + FUNCTION = "load_image_mask" + + def load_image_mask(self, image, channel): + image_tensor, mask_tensor = super().load_image(image) c = channel[0].upper() - if c in i.getbands(): - mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 - mask = torch.from_numpy(mask) - if c == 'A': - mask = 1. - mask + + if c == 'A': + return (mask_tensor,) + + channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0) + + if channel_idx < image_tensor.shape[-1]: + return (image_tensor[..., channel_idx].clone(),) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (mask.unsqueeze(0),) + empty_mask = torch.zeros( + image_tensor.shape[:-1], + dtype=image_tensor.dtype, + device=image_tensor.device + ) + return (empty_mask,) @classmethod def IS_CHANGED(s, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - m = hashlib.sha256() - with open(image_path, 'rb') as f: - m.update(f.read()) - return m.digest().hex() - - @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - - return True + return super().IS_CHANGED(image) class LoadImageOutput(LoadImage):