diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 32b60afda..cd8111683 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -39,10 +39,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou inverse_mask = torch.ones_like(mask) - mask - source_portion = mask * source[..., :visible_height, :visible_width] - destination_portion = inverse_mask * destination[..., top:bottom, left:right] + source_rgb = source[:, :3, :visible_height, :visible_width] + dest_slice = destination[..., top:bottom, left:right] + + if destination.shape[1] == 4: + if torch.max(dest_slice) == 0: + destination[:, :3, top:bottom, left:right] = source_rgb + destination[:, 3:4, top:bottom, left:right] = mask + else: + destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3]) + destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4]) + else: + source_portion = mask * source_rgb + destination_portion = inverse_mask * dest_slice + destination[..., top:bottom, left:right] = source_portion + destination_portion - destination[..., top:bottom, left:right] = source_portion + destination_portion return destination class LatentCompositeMasked(IO.ComfyNode): @@ -82,18 +93,23 @@ class ImageCompositeMasked(IO.ComfyNode): search_aliases=["paste image", "overlay", "layer"], category="image", inputs=[ - IO.Image.Input("destination"), IO.Image.Input("source"), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Boolean.Input("resize_source", default=False), + IO.Image.Input("destination", optional=True), IO.Mask.Input("mask", optional=True), ], outputs=[IO.Image.Output()], ) @classmethod - def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: + def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput: + if destination is None: # transparent rgba + B, H, W, C = source.shape + destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device) + if C == 3: + source = torch.nn.functional.pad(source, (0, 1), value=1.0) destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) @@ -397,33 +413,6 @@ class ClipVisionToMask(IO.ComfyNode): clip_vision_to_mask = execute -class ConcatMask(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ConcatMask", - search_aliases=["add mask", "concat mask", "merge mask"], - category="mask", - inputs=[ - IO.Mask.Input("mask"), - IO.Image.Input("image"), - ], - outputs=[IO.Image.Output("rgba_image"), IO.Mask.Output("input_mask")], - ) - @classmethod - def execute(cls, mask, image): - if image.shape[-1] == 3: - image = image.movedim(-1, 1) - target_h, target_w = image.shape[2], image.shape[3] - if mask.shape[-2:] != (target_h, target_w): - mask = torch.nn.functional.interpolate( - mask, size=(target_h, target_w), mode='bicubic', align_corners=False - ) - rgba = torch.cat([image, mask], dim = 1) - return IO.NodeOutput(rgba.movedim(1, -1), mask) - - concat_mask = execute - class ThresholdMask(IO.ComfyNode): @classmethod def define_schema(cls): @@ -487,7 +476,6 @@ class MaskExtension(ComfyExtension): GrowMask, ThresholdMask, MaskPreview, - ConcatMask, ClipVisionToMask ]