diff --git a/comfy_extras/nodes/nodes_inpainting.py b/comfy_extras/nodes/nodes_inpainting.py index 4964bbb33..aa7399ba9 100644 --- a/comfy_extras/nodes/nodes_inpainting.py +++ b/comfy_extras/nodes/nodes_inpainting.py @@ -1,6 +1,9 @@ +from typing import NamedTuple, Optional + import torch import torch.nn.functional as F -from typing import NamedTuple, Optional +from jaxtyping import Float +from torch import Tensor from comfy.component_model.tensor_types import MaskBatch, ImageBatch from comfy.nodes.package_typing import CustomNode @@ -14,41 +17,83 @@ class CompositeContext(NamedTuple): height: int -def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None) -> ImageBatch: +def composite( + destination: Float[Tensor, "B C H W"], + source: Float[Tensor, "B C H W"], + x: int, + y: int, + mask: Optional[MaskBatch] = None, +) -> ImageBatch: + """ + Composites a source image onto a destination image at a given (x, y) coordinate + using an optional mask. + + This simplified implementation first creates a destination-sized, zero-padded + version of the source image. This canvas is then blended with the destination, + which cleanly handles all boundary conditions (e.g., source placed partially + or fully off-screen). + + Args: + destination (ImageBatch): The background image tensor in (B, C, H, W) format. + source (ImageBatch): The foreground image tensor to composite, also (B, C, H, W). + x (int): The x-coordinate (from left) to place the top-left corner of the source. + y (int): The y-coordinate (from top) to place the top-left corner of the source. + mask (Optional[MaskBatch]): An optional luma mask tensor with the same batch size, + height, and width as the destination (B, H, W). + Values of 1.0 indicate using the source pixel, while + 0.0 indicates using the destination pixel. If None, + the source is treated as fully opaque. + + Returns: + ImageBatch: The resulting composited image tensor. + """ + if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor): + raise TypeError("destination and source must be torch.Tensor") + if destination.dim() != 4 or source.dim() != 4: + raise ValueError("destination and source must be 4D tensors (B, C, H, W)") + source = source.to(destination.device) + if source.shape[0] != destination.shape[0]: + if destination.shape[0] % source.shape[0] != 0: + raise ValueError( + "Destination batch size must be a multiple of source batch size for broadcasting." + ) source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1) - x, y = int(x), int(y) - left, top = x, y - right, bottom = left + source.shape[3], top + source.shape[2] + dest_b, dest_c, dest_h, dest_w = destination.shape + src_h, src_w = source.shape[2:] + dest_y_start = max(0, y) + dest_y_end = min(dest_h, y + src_h) + dest_x_start = max(0, x) + dest_x_end = min(dest_w, x + src_w) + + src_y_start = max(0, -y) + src_y_end = src_y_start + (dest_y_end - dest_y_start) + src_x_start = max(0, -x) + src_x_end = src_x_start + (dest_x_end - dest_x_start) + + if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end: + return destination + padded_source = torch.zeros_like(destination) + padded_source[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = source[ + :, :, src_y_start:src_y_end, src_x_start:src_x_end + ] if mask is None: - mask = torch.ones_like(source) + final_mask = torch.zeros(dest_b, 1, dest_h, dest_w, device=destination.device) + final_mask[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = 1.0 else: - mask = mask.to(destination.device, copy=True) - if mask.dim() == 2: - mask = mask.unsqueeze(0) - if mask.dim() == 3: - mask = mask.unsqueeze(1) - if mask.shape[0] != source.shape[0]: - mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1) + if mask.dim() != 3 or mask.shape[0] != dest_b or mask.shape[1] != dest_h or mask.shape[2] != dest_w: + raise ValueError( + f"Provided mask shape {mask.shape} is invalid. " + f"Expected (batch, height, width): ({dest_b}, {dest_h}, {dest_w})." + ) + final_mask = mask.to(destination.device).unsqueeze(1) - dest_left, dest_top = max(0, left), max(0, top) - dest_right, dest_bottom = min(destination.shape[3], right), min(destination.shape[2], bottom) + blended_image = padded_source * final_mask + destination * (1.0 - final_mask) - if dest_right <= dest_left or dest_bottom <= dest_top: return destination - - src_left, src_top = dest_left - left, dest_top - top - src_right, src_bottom = dest_right - left, dest_bottom - - destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right] - source_portion = source[:, :, src_top:src_bottom, src_left:src_right] - mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right] - - blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion)) - destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion - return destination + return blended_image def parse_margin(margin_str: str) -> tuple[int, int, int, int]: