diff --git a/comfy_extras/silver_custom.py b/comfy_extras/silver_custom.py index 03a073c07..c72e99923 100644 --- a/comfy_extras/silver_custom.py +++ b/comfy_extras/silver_custom.py @@ -1,5 +1,8 @@ +import PIL +import numpy as np import cv2 import torch +from PIL.Image import Image class ExpandImageMask: @classmethod @@ -12,7 +15,7 @@ class ExpandImageMask: CATEGORY = "mask" - RETURN_TYPES = ("IMAGE", "MASK", ) + RETURN_TYPES = ("MASK", ) FUNCTION = "image_to_mask_image" def image_to_mask_image(self, images): @@ -40,23 +43,34 @@ class ExpandImageMask: # Threshold binary mask image again im_bw = cv2.threshold(im_bw, thresh, 255, cv2.THRESH_BINARY)[1] - # Invert binary mask image - # im_bw = cv2.bitwise_not(im_bw) + # Convert binary mask image to 3-channel RGB image + mask_image_rgb = np.zeros_like(i) + mask_image_rgb[:, :, 0] = im_bw + mask_image_rgb[:, :, 1] = im_bw + mask_image_rgb[:, :, 2] = im_bw + pil_image = PIL.Image.fromarray(np.uint8(mask_image_rgb)) - # Convert binary mask image to PyTorch tensor - img = torch.from_numpy(im_bw).unsqueeze(0).float() + # create a new alpha channel with all pixels set to 255 (fully opaque) + alpha = PIL.Image.new('L', pil_image.size, 255) + + # iterate over each pixel and set the alpha channel to 0 if the RGB values are white + for x in range(pil_image.width): + for y in range(pil_image.height): + if pil_image.getpixel((x, y)) == (255, 255, 255): + alpha.putpixel((x, y), 0) + + # merge the alpha channel with the original image + pil_image.putalpha(alpha) # Append mask image tensor to list - mask_images.append(img) + mask_images.append(1. - torch.from_numpy(np.array(pil_image.getchannel('A')).astype(np.float32) / 255.0)) - # Stack list of mask image tensors into a single tensor - mask_images_tensor = torch.cat(mask_images) - - # Return tuple of mask images and single mask image - single_mask_image = mask_images_tensor[0, :, :] - return mask_images_tensor, single_mask_image + return mask_images NODE_CLASS_MAPPINGS = { "ExpandImageMask": ExpandImageMask } +NODE_DISPLAY_NAME_MAPPINGS = { + "ExpandImageMask": "Expand Image Mask" +}