From 800bf842a5e83ed5760eed881bbce96deb7b59f5 Mon Sep 17 00:00:00 2001 From: David Lee <47388918+Pizzawookiee@users.noreply.github.com> Date: Sat, 2 May 2026 15:23:42 -0400 Subject: [PATCH] Add VideoLatentCompositeMasked and RGBMaskToLatentMask nodes --- comfy_extras/nodes_mask.py | 162 +++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 8ca947718..14db5c630 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -46,6 +46,110 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou destination[..., top:bottom, left:right] = source_portion + destination_portion return destination +def video_latent_composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False): + # destination/source shape: [B, C, F, H, W] + source = source.to(destination.device) + + # 1. Spatial Resizing for Source + if resize_source: + # size=(Frames, Height, Width). We keep source's F, but match destination's H, W + target_size = (source.shape[2], destination.shape[3], destination.shape[4]) + source = torch.nn.functional.interpolate( + source, + size=target_size, + mode="trilinear", + align_corners=False + ) + + # 2. Coordinate Scaling + x_latent = x // multiplier + y_latent = y // multiplier + + # 3. Mask Processing (Input: [F, H, W]) + if mask is None: + mask = torch.ones_like(source) + else: + mask = mask.to(destination.device, copy=True) + + # Convert [F, H, W] -> [1, 1, F, H, W] + # This allows it to broadcast across any Batch or Channel in 'source' + mask = mask.unsqueeze(0).unsqueeze(0) + + # Resize mask spatially, preserving its frame count + # size=(mask_frames, source_height, source_width) + mask_target_size = (mask.shape[2], source.shape[3], source.shape[4]) + mask = torch.nn.functional.interpolate( + mask, + size=mask_target_size, + mode="trilinear", + align_corners=False + ) + + # 4. Dimension Calculations for Spatial Slicing + dst_h, dst_w = destination.shape[3], destination.shape[4] + src_h, src_w = source.shape[3], source.shape[4] + + # Calculate visible overlap region + visible_h = max(0, min(y_latent + src_h, dst_h) - max(0, y_latent)) + visible_w = max(0, min(x_latent + src_w, dst_w) - max(0, x_latent)) + + if visible_h <= 0 or visible_w <= 0: + return destination + + # Determine slicing offsets + src_top = max(0, -y_latent) + src_left = max(0, -x_latent) + dst_top = max(0, y_latent) + dst_left = max(0, x_latent) + + # 5. Slicing and Blending + # destination/source/mask are now all 5D: [B, C, F, H, W] + # We slice only the H and W dimensions (indices 3 and 4) + m = mask[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w] + s = source[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w] + d = destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] + + # Combine using the mask + destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] = (m * s) + ((1.0 - m) * d) + + return destination + +def convert_rgb_mask_to_latent_mask( + mask: torch.Tensor, + k: int, + spatial_downsample_h: int, + spatial_downsample_w: int +) -> torch.Tensor: + """ + Converts [T, H, W] mask to [T_latent, H_latent, W_latent]. + Handles non-square spatial downsampling. + """ + # 1. Temporal Sampling + # Select first frame and every k-th frame thereafter + mask0 = mask[0:1] + mask1 = mask[1::k] + sampled = torch.cat([mask0, mask1], dim=0) # [T_latent, H, W] + + # 2. Prepare for Spatial Interpolation + # Shape: [Batch=1, Channels=1, Depth=T_latent, Height=H, Width=W] + sampled = sampled.unsqueeze(0).unsqueeze(0) + + # 3. Calculate New Spatial Dimensions + h_latent = sampled.shape[-2] // spatial_downsample_h + w_latent = sampled.shape[-1] // spatial_downsample_w + + # 4. Interpolate + # We maintain the temporal count (sampled.shape[2]) + # but resize H and W independently + pooled = torch.nn.functional.interpolate( + sampled, + size=(sampled.shape[2], h_latent, w_latent), + mode="nearest" + ) + + # 5. Return to [T_latent, H_latent, W_latent] + return pooled.squeeze(0).squeeze(0) + class LatentCompositeMasked(IO.ComfyNode): @classmethod def define_schema(cls): @@ -74,6 +178,40 @@ class LatentCompositeMasked(IO.ComfyNode): composite = execute # TODO: remove +class VideoLatentCompositeMasked(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VideoLatentCompositeMasked", + search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], + category="latent", + inputs=[ + IO.Latent.Input("destination"), + IO.Latent.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Latent.Output()], + ) + + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask=None) -> IO.NodeOutput: + output = destination.copy() + # Ensure we work on a copy of the samples to remain non-destructive + dst_samples = destination["samples"].clone() + src_samples = source["samples"] + + output["samples"] = video_latent_composite( + dst_samples, + src_samples, + x, y, + mask, + multiplier=8, + resize_source=resize_source + ) + return IO.NodeOutput(output) class ImageCompositeMasked(IO.ComfyNode): @classmethod @@ -398,6 +536,28 @@ class ThresholdMask(IO.ComfyNode): image_to_mask = execute # TODO: remove +class RGBMaskToLatentMask(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RGBMasktoLatentMask", + search_aliases=["rgb mask to latent mask", "rgb mask", "latent mask"], + description="Helpful for applying masks to video latents if the VAE uses spatial downsampling.", + category="latent", + inputs=[ + IO.Mask.Input("mask", optional=False), + IO.Vae.Input("vae", optional=False), + ], + outputs=[IO.Mask.Output()], + ) + + @classmethod + def execute(cls, mask, vae) -> IO.NodeOutput: + # Ensure we work on a copy of the mask to remain non-destructive + mask_copy = mask.clone() + downscale_ratio = vae.downscale_ratio + k = (mask.shape[0] - 1) // (downscale_ratio[0](mask.shape[0]) - 1) if (downscale_ratio[0](mask.shape[0]) - 1) > 1 else 1 + return IO.NodeOutput(convert_rgb_mask_to_latent_mask(mask_copy, k, spatial_downsample_h = downscale_ratio[1], spatial_downsample_w = downscale_ratio[2])) # Mask Preview - original implement from # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 @@ -428,6 +588,7 @@ class MaskExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ LatentCompositeMasked, + VideoLatentCompositeMasked, ImageCompositeMasked, MaskToImage, ImageToMask, @@ -439,6 +600,7 @@ class MaskExtension(ComfyExtension): FeatherMask, GrowMask, ThresholdMask, + RGBMaskToLatentMask, MaskPreview, ]