mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 15:20:51 +08:00
simplify composite function
This commit is contained in:
parent
a464dfa3db
commit
35899ba20e
@ -1,6 +1,9 @@
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import NamedTuple, Optional
|
||||
from jaxtyping import Float
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
@ -14,41 +17,83 @@ class CompositeContext(NamedTuple):
|
||||
height: int
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, batch: int, height: int, width: int) -> torch.Tensor:
    """Return *mask* reshaped to ``(batch, 1, height, width)``.

    Accepts ``(H, W)``, ``(B, H, W)`` or ``(B, 1, H, W)`` input. A mask whose
    batch size evenly divides *batch* is repeated along the batch axis, in the
    same way the source image is.

    Raises:
        ValueError: If the mask cannot be brought to the expected shape.
    """
    given_shape = mask.shape  # reported verbatim in the error message
    if mask.dim() == 2:
        mask = mask.unsqueeze(0)
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)

    is_valid = (
        mask.dim() == 4
        and mask.shape[1] == 1
        and mask.shape[2] == height
        and mask.shape[3] == width
        and mask.shape[0] > 0
        and batch % mask.shape[0] == 0
    )
    if not is_valid:
        raise ValueError(
            f"Provided mask shape {given_shape} is invalid. "
            f"Expected (batch, height, width): ({batch}, {height}, {width})."
        )

    if mask.shape[0] != batch:
        mask = mask.repeat(batch // mask.shape[0], 1, 1, 1)
    return mask


def composite(
    destination: Float[Tensor, "B C H W"],
    source: Float[Tensor, "B C H W"],
    x: int,
    y: int,
    mask: Optional[MaskBatch] = None,
) -> ImageBatch:
    """Composite *source* onto *destination* with its top-left corner at (x, y).

    Only the rectangle where the source actually overlaps the destination is
    touched; everything else is copied from the destination unchanged. This
    handles a source placed partially or fully outside the destination.

    Args:
        destination: Background tensor in (B, C, H, W) layout. It is never
            modified in place.
        source: Foreground tensor in (B, C, H, W) layout. It is moved to the
            destination's device and dtype. Its batch size must equal the
            destination's or evenly divide it (it is then repeated).
        x: Column in the destination (from the left) of the source's
            top-left corner. May be negative.
        y: Row in the destination (from the top) of the source's top-left
            corner. May be negative.
        mask: Optional blend weights with the destination's height and width,
            shaped (B, H, W); (H, W) and (B, 1, H, W) are also accepted.
            1.0 selects the source pixel and 0.0 keeps the destination pixel.
            Mask values outside the overlap rectangle have no effect. If
            None, the source is fully opaque.

    Returns:
        A new tensor with the destination's shape and dtype holding the
        composited result. When the source does not overlap the destination
        at all, the ``destination`` object itself is returned (no copy).

    Raises:
        TypeError: If ``destination`` or ``source`` is not a ``torch.Tensor``.
        ValueError: If either image is not 4D, the batch sizes cannot be
            broadcast, or the mask shape is invalid.
    """
    if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor):
        raise TypeError("destination and source must be torch.Tensor")
    if destination.dim() != 4 or source.dim() != 4:
        raise ValueError("destination and source must be 4D tensors (B, C, H, W)")

    dest_b, _, dest_h, dest_w = destination.shape

    # Match dtype as well as device so the result keeps the destination's
    # precision instead of being promoted by the blend arithmetic.
    source = source.to(device=destination.device, dtype=destination.dtype)

    src_b = source.shape[0]
    if src_b != dest_b:
        if src_b == 0 or dest_b % src_b != 0:
            raise ValueError(
                "Destination batch size must be a multiple of source batch size for broadcasting."
            )
        source = source.repeat(dest_b // src_b, 1, 1, 1)

    x, y = int(x), int(y)
    src_h, src_w = source.shape[2:]

    # Overlap rectangle in destination coordinates.
    dest_y_start = max(0, y)
    dest_y_end = min(dest_h, y + src_h)
    dest_x_start = max(0, x)
    dest_x_end = min(dest_w, x + src_w)

    # No overlap: nothing to composite.
    if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end:
        return destination

    # The same rectangle in source coordinates (a negative offset crops the
    # source's leading rows/columns).
    src_y_start = max(0, -y)
    src_y_end = src_y_start + (dest_y_end - dest_y_start)
    src_x_start = max(0, -x)
    src_x_end = src_x_start + (dest_x_end - dest_x_start)

    region = (
        slice(None),
        slice(None),
        slice(dest_y_start, dest_y_end),
        slice(dest_x_start, dest_x_end),
    )
    visible_source = source[:, :, src_y_start:src_y_end, src_x_start:src_x_end]

    blended_image = destination.clone()
    if mask is None:
        blended_image[region] = visible_source
    else:
        # Blend only inside the overlap: outside it there is no source pixel,
        # so the destination must be kept whatever the mask says.
        alpha = _expand_mask(mask, dest_b, dest_h, dest_w)[region]
        alpha = alpha.to(device=destination.device, dtype=destination.dtype)
        blended_image[region] = visible_source * alpha + destination[region] * (1.0 - alpha)

    return blended_image
|
||||
|
||||
|
||||
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user