diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 14db5c630..5d83aff0c 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -46,74 +46,6 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou destination[..., top:bottom, left:right] = source_portion + destination_portion return destination -def video_latent_composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False): - # destination/source shape: [B, C, F, H, W] - source = source.to(destination.device) - - # 1. Spatial Resizing for Source - if resize_source: - # size=(Frames, Height, Width). We keep source's F, but match destination's H, W - target_size = (source.shape[2], destination.shape[3], destination.shape[4]) - source = torch.nn.functional.interpolate( - source, - size=target_size, - mode="trilinear", - align_corners=False - ) - - # 2. Coordinate Scaling - x_latent = x // multiplier - y_latent = y // multiplier - - # 3. Mask Processing (Input: [F, H, W]) - if mask is None: - mask = torch.ones_like(source) - else: - mask = mask.to(destination.device, copy=True) - - # Convert [F, H, W] -> [1, 1, F, H, W] - # This allows it to broadcast across any Batch or Channel in 'source' - mask = mask.unsqueeze(0).unsqueeze(0) - - # Resize mask spatially, preserving its frame count - # size=(mask_frames, source_height, source_width) - mask_target_size = (mask.shape[2], source.shape[3], source.shape[4]) - mask = torch.nn.functional.interpolate( - mask, - size=mask_target_size, - mode="trilinear", - align_corners=False - ) - - # 4. Dimension Calculations for Spatial Slicing - dst_h, dst_w = destination.shape[3], destination.shape[4] - src_h, src_w = source.shape[3], source.shape[4] - - # Calculate visible overlap region - visible_h = max(0, min(y_latent + src_h, dst_h) - max(0, y_latent)) - visible_w = max(0, min(x_latent + src_w, dst_w) - max(0, x_latent)) - - if visible_h <= 0 or visible_w <= 0: - return destination - - # Determine slicing offsets - src_top = max(0, -y_latent) - src_left = max(0, -x_latent) - dst_top = max(0, y_latent) - dst_left = max(0, x_latent) - - # 5. Slicing and Blending - # destination/source/mask are now all 5D: [B, C, F, H, W] - # We slice only the H and W dimensions (indices 3 and 4) - m = mask[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w] - s = source[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w] - d = destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] - - # Combine using the mask - destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] = (m * s) + ((1.0 - m) * d) - - return destination - def convert_rgb_mask_to_latent_mask( mask: torch.Tensor, k: int, @@ -177,42 +109,7 @@ class LatentCompositeMasked(IO.ComfyNode): return IO.NodeOutput(output) composite = execute # TODO: remove - -class VideoLatentCompositeMasked(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="VideoLatentCompositeMasked", - search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], - category="latent", - inputs=[ - IO.Latent.Input("destination"), - IO.Latent.Input("source"), - IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), - IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), - IO.Boolean.Input("resize_source", default=False), - IO.Mask.Input("mask", optional=True), - ], - outputs=[IO.Latent.Output()], - ) - - @classmethod - def execute(cls, destination, source, x, y, resize_source, mask=None) -> IO.NodeOutput: - output = destination.copy() - # Ensure we work on a copy of the samples to remain non-destructive - dst_samples = destination["samples"].clone() - src_samples = source["samples"] - - output["samples"] = video_latent_composite( - dst_samples, - src_samples, - x, y, - mask, - multiplier=8, - resize_source=resize_source - ) - return IO.NodeOutput(output) - + class ImageCompositeMasked(IO.ComfyNode): @classmethod def define_schema(cls): @@ -588,7 +485,6 @@ class MaskExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ LatentCompositeMasked, - VideoLatentCompositeMasked, ImageCompositeMasked, MaskToImage, ImageToMask,