From 8c16b98008de1bb9751adf6b9a31cbdd5996a75e Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sun, 9 Apr 2023 22:36:39 -0600 Subject: [PATCH] Image composite node --- comfy_extras/nodes_post_processing.py | 38 +++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ceb0b92e1..163139fef 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -435,7 +435,43 @@ class Merge: return (result,) +class Composite: + def __init__(self): + pass + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image_a": ("IMAGE",), + "image_b": ("IMAGE",), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "composite" + + CATEGORY = "image/postprocessing" + + def composite(self, image_a: torch.Tensor, image_b: torch.Tensor, mask: torch.Tensor): + batch_size, height, width, _ = image_a.shape + result = torch.zeros_like(image_a) + + for b in range(batch_size): + img_a = (image_a[b] * 255).to(torch.uint8).numpy() + img_b = (image_b[b] * 255).to(torch.uint8).numpy() + img_mask = (mask * 255).to(torch.uint8).numpy() + pil_image_a = Image.fromarray(img_a, mode='RGB') + pil_image_b = Image.fromarray(img_b, mode='RGB') + pil_image_mask = Image.fromarray(img_mask, mode='L') + + output_image = Image.composite(pil_image_a, pil_image_b, pil_image_mask) + + output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255 + result[b] = output_array + + return (result,) NODE_CLASS_MAPPINGS = { "ImageBlend": Blend, @@ -447,6 +483,7 @@ NODE_CLASS_MAPPINGS = { "ImageGetChannel": GetChannel, "ImageSplit": Split, "ImageMerge": Merge, + "ImageComposite": Composite, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -459,4 +496,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageGetChannel": "Extract Channel", "ImageSplit": "Split Channels", "ImageMerge": "Merge Channels", + "ImageComposite": "Composite Images", }