simplify composite function

doctorpangloss 2025-07-11 13:52:25 -07:00
parent a464dfa3db
commit 35899ba20e


@@ -1,6 +1,9 @@
-from typing import NamedTuple, Optional
 import torch
 import torch.nn.functional as F
+from typing import NamedTuple, Optional
+from jaxtyping import Float
+from torch import Tensor
 
 from comfy.component_model.tensor_types import MaskBatch, ImageBatch
 from comfy.nodes.package_typing import CustomNode
@@ -14,41 +17,83 @@ class CompositeContext(NamedTuple):
     height: int
 
-def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None) -> ImageBatch:
+def composite(
+        destination: Float[Tensor, "B C H W"],
+        source: Float[Tensor, "B C H W"],
+        x: int,
+        y: int,
+        mask: Optional[MaskBatch] = None,
+) -> ImageBatch:
+    """
+    Composites a source image onto a destination image at a given (x, y) coordinate
+    using an optional mask.
+
+    This simplified implementation first creates a destination-sized, zero-padded
+    version of the source image. This canvas is then blended with the destination,
+    which cleanly handles all boundary conditions (e.g., source placed partially
+    or fully off-screen).
+
+    Args:
+        destination (ImageBatch): The background image tensor in (B, C, H, W) format.
+        source (ImageBatch): The foreground image tensor to composite, also (B, C, H, W).
+        x (int): The x-coordinate (from left) to place the top-left corner of the source.
+        y (int): The y-coordinate (from top) to place the top-left corner of the source.
+        mask (Optional[MaskBatch]): An optional luma mask tensor with the same batch size,
+            height, and width as the destination (B, H, W). Values of 1.0 indicate using
+            the source pixel, while 0.0 indicates using the destination pixel. If None,
+            the source is treated as fully opaque.
+
+    Returns:
+        ImageBatch: The resulting composited image tensor.
+    """
+    if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor):
+        raise TypeError("destination and source must be torch.Tensor")
+    if destination.dim() != 4 or source.dim() != 4:
+        raise ValueError("destination and source must be 4D tensors (B, C, H, W)")
     source = source.to(destination.device)
     if source.shape[0] != destination.shape[0]:
+        if destination.shape[0] % source.shape[0] != 0:
+            raise ValueError(
+                "Destination batch size must be a multiple of source batch size for broadcasting."
+            )
         source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
     x, y = int(x), int(y)
-    left, top = x, y
-    right, bottom = left + source.shape[3], top + source.shape[2]
+    dest_b, dest_c, dest_h, dest_w = destination.shape
+    src_h, src_w = source.shape[2:]
+
+    dest_y_start = max(0, y)
+    dest_y_end = min(dest_h, y + src_h)
+    dest_x_start = max(0, x)
+    dest_x_end = min(dest_w, x + src_w)
+    src_y_start = max(0, -y)
+    src_y_end = src_y_start + (dest_y_end - dest_y_start)
+    src_x_start = max(0, -x)
+    src_x_end = src_x_start + (dest_x_end - dest_x_start)
+
+    if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end:
+        return destination
+
+    padded_source = torch.zeros_like(destination)
+    padded_source[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = source[
+        :, :, src_y_start:src_y_end, src_x_start:src_x_end
+    ]
     if mask is None:
-        mask = torch.ones_like(source)
+        final_mask = torch.zeros(dest_b, 1, dest_h, dest_w, device=destination.device)
+        final_mask[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = 1.0
     else:
-        mask = mask.to(destination.device, copy=True)
-        if mask.dim() == 2:
-            mask = mask.unsqueeze(0)
-        if mask.dim() == 3:
-            mask = mask.unsqueeze(1)
-        if mask.shape[0] != source.shape[0]:
-            mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
+        if mask.dim() != 3 or mask.shape[0] != dest_b or mask.shape[1] != dest_h or mask.shape[2] != dest_w:
+            raise ValueError(
+                f"Provided mask shape {mask.shape} is invalid. "
+                f"Expected (batch, height, width): ({dest_b}, {dest_h}, {dest_w})."
+            )
+        final_mask = mask.to(destination.device).unsqueeze(1)
-    dest_left, dest_top = max(0, left), max(0, top)
-    dest_right, dest_bottom = min(destination.shape[3], right), min(destination.shape[2], bottom)
-    if dest_right <= dest_left or dest_bottom <= dest_top: return destination
-    src_left, src_top = dest_left - left, dest_top - top
-    src_right, src_bottom = dest_right - left, dest_bottom - top
-    destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
-    source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
-    mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right]
-    blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
-    destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
-    return destination
+    blended_image = padded_source * final_mask + destination * (1.0 - final_mask)
+    return blended_image
 
 def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
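
As a sanity check on the new behavior, here is a minimal usage sketch (not part of the commit); it assumes the patched composite function above is importable in the current scope, and the tensor shapes are purely illustrative:

import torch

# A 4x4 black canvas and a 2x2 white patch, both (B, C, H, W).
destination = torch.zeros(1, 3, 4, 4)
source = torch.ones(1, 3, 2, 2)

# Fully on-canvas paste at (1, 1): rows/cols 1..2 take the source pixels.
out = composite(destination, source, x=1, y=1)
assert out[0, 0, 1, 1] == 1.0 and out[0, 0, 0, 0] == 0.0

# Partially off-canvas at (-1, -1): only the overlapping top-left pixel
# receives source data; the zero-padded canvas clamps the rest away.
out = composite(destination, source, x=-1, y=-1)
assert out[0, 0, 0, 0] == 1.0 and out[0, 0, 1, 1] == 0.0

# Destination-sized (B, H, W) luma mask: a value of 0.5 blends the padded
# source and the destination equally.
mask = torch.full((1, 4, 4), 0.5)
out = composite(destination, source, x=1, y=1, mask=mask)
assert torch.isclose(out[0, 0, 1, 1], torch.tensor(0.5))

Note the behavioral difference visible in the diff: the old code wrote the blended region back into destination in place, while the new code allocates a destination-sized padded canvas and returns a fresh blended tensor, trading one extra full-size allocation for much simpler boundary logic.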