Add VideoLatentCompositeMasked and RGBMaskToLatentMask nodes

This commit is contained in:
David Lee 2026-05-02 15:23:42 -04:00 committed by GitHub
parent 3e3ed8cc2a
commit 800bf842a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,110 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
def video_latent_composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False):
# destination/source shape: [B, C, F, H, W]
source = source.to(destination.device)
# 1. Spatial Resizing for Source
if resize_source:
# size=(Frames, Height, Width). We keep source's F, but match destination's H, W
target_size = (source.shape[2], destination.shape[3], destination.shape[4])
source = torch.nn.functional.interpolate(
source,
size=target_size,
mode="trilinear",
align_corners=False
)
# 2. Coordinate Scaling
x_latent = x // multiplier
y_latent = y // multiplier
# 3. Mask Processing (Input: [F, H, W])
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.to(destination.device, copy=True)
# Convert [F, H, W] -> [1, 1, F, H, W]
# This allows it to broadcast across any Batch or Channel in 'source'
mask = mask.unsqueeze(0).unsqueeze(0)
# Resize mask spatially, preserving its frame count
# size=(mask_frames, source_height, source_width)
mask_target_size = (mask.shape[2], source.shape[3], source.shape[4])
mask = torch.nn.functional.interpolate(
mask,
size=mask_target_size,
mode="trilinear",
align_corners=False
)
# 4. Dimension Calculations for Spatial Slicing
dst_h, dst_w = destination.shape[3], destination.shape[4]
src_h, src_w = source.shape[3], source.shape[4]
# Calculate visible overlap region
visible_h = max(0, min(y_latent + src_h, dst_h) - max(0, y_latent))
visible_w = max(0, min(x_latent + src_w, dst_w) - max(0, x_latent))
if visible_h <= 0 or visible_w <= 0:
return destination
# Determine slicing offsets
src_top = max(0, -y_latent)
src_left = max(0, -x_latent)
dst_top = max(0, y_latent)
dst_left = max(0, x_latent)
# 5. Slicing and Blending
# destination/source/mask are now all 5D: [B, C, F, H, W]
# We slice only the H and W dimensions (indices 3 and 4)
m = mask[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w]
s = source[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w]
d = destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w]
# Combine using the mask
destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] = (m * s) + ((1.0 - m) * d)
return destination
def convert_rgb_mask_to_latent_mask(
mask: torch.Tensor,
k: int,
spatial_downsample_h: int,
spatial_downsample_w: int
) -> torch.Tensor:
"""
Converts [T, H, W] mask to [T_latent, H_latent, W_latent].
Handles non-square spatial downsampling.
"""
# 1. Temporal Sampling
# Select first frame and every k-th frame thereafter
mask0 = mask[0:1]
mask1 = mask[1::k]
sampled = torch.cat([mask0, mask1], dim=0) # [T_latent, H, W]
# 2. Prepare for Spatial Interpolation
# Shape: [Batch=1, Channels=1, Depth=T_latent, Height=H, Width=W]
sampled = sampled.unsqueeze(0).unsqueeze(0)
# 3. Calculate New Spatial Dimensions
h_latent = sampled.shape[-2] // spatial_downsample_h
w_latent = sampled.shape[-1] // spatial_downsample_w
# 4. Interpolate
# We maintain the temporal count (sampled.shape[2])
# but resize H and W independently
pooled = torch.nn.functional.interpolate(
sampled,
size=(sampled.shape[2], h_latent, w_latent),
mode="nearest"
)
# 5. Return to [T_latent, H_latent, W_latent]
return pooled.squeeze(0).squeeze(0)
class LatentCompositeMasked(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -74,6 +178,40 @@ class LatentCompositeMasked(IO.ComfyNode):
composite = execute # TODO: remove
class VideoLatentCompositeMasked(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VideoLatentCompositeMasked",
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
category="latent",
inputs=[
IO.Latent.Input("destination"),
IO.Latent.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
IO.Boolean.Input("resize_source", default=False),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Latent.Output()],
)
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask=None) -> IO.NodeOutput:
output = destination.copy()
# Ensure we work on a copy of the samples to remain non-destructive
dst_samples = destination["samples"].clone()
src_samples = source["samples"]
output["samples"] = video_latent_composite(
dst_samples,
src_samples,
x, y,
mask,
multiplier=8,
resize_source=resize_source
)
return IO.NodeOutput(output)
class ImageCompositeMasked(IO.ComfyNode):
@classmethod
@ -398,6 +536,28 @@ class ThresholdMask(IO.ComfyNode):
image_to_mask = execute # TODO: remove
class RGBMaskToLatentMask(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RGBMasktoLatentMask",
search_aliases=["rgb mask to latent mask", "rgb mask", "latent mask"],
description="Helpful for applying masks to video latents if the VAE uses spatial downsampling.",
category="latent",
inputs=[
IO.Mask.Input("mask", optional=False),
IO.Vae.Input("vae", optional=False),
],
outputs=[IO.Mask.Output()],
)
@classmethod
def execute(cls, mask, vae) -> IO.NodeOutput:
# Ensure we work on a copy of the mask to remain non-destructive
mask_copy = mask.clone()
downscale_ratio = vae.downscale_ratio
k = (mask.shape[0] - 1) // (downscale_ratio[0](mask.shape[0]) - 1) if (downscale_ratio[0](mask.shape[0]) - 1) > 1 else 1
return IO.NodeOutput(convert_rgb_mask_to_latent_mask(mask_copy, k, spatial_downsample_h = downscale_ratio[1], spatial_downsample_w = downscale_ratio[2]))
# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
@ -428,6 +588,7 @@ class MaskExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LatentCompositeMasked,
VideoLatentCompositeMasked,
ImageCompositeMasked,
MaskToImage,
ImageToMask,
@ -439,6 +600,7 @@ class MaskExtension(ComfyExtension):
FeatherMask,
GrowMask,
ThresholdMask,
RGBMaskToLatentMask,
MaskPreview,
]