mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 21:00:16 +08:00
simplify composite function
This commit is contained in:
parent
a464dfa3db
commit
35899ba20e
@ -1,6 +1,9 @@
|
|||||||
|
from typing import NamedTuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import NamedTuple, Optional
|
from jaxtyping import Float
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
|
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
|
||||||
from comfy.nodes.package_typing import CustomNode
|
from comfy.nodes.package_typing import CustomNode
|
||||||
@ -14,41 +17,83 @@ class CompositeContext(NamedTuple):
|
|||||||
height: int
|
height: int
|
||||||
|
|
||||||
|
|
||||||
def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None) -> ImageBatch:
|
def composite(
|
||||||
|
destination: Float[Tensor, "B C H W"],
|
||||||
|
source: Float[Tensor, "B C H W"],
|
||||||
|
x: int,
|
||||||
|
y: int,
|
||||||
|
mask: Optional[MaskBatch] = None,
|
||||||
|
) -> ImageBatch:
|
||||||
|
"""
|
||||||
|
Composites a source image onto a destination image at a given (x, y) coordinate
|
||||||
|
using an optional mask.
|
||||||
|
|
||||||
|
This simplified implementation first creates a destination-sized, zero-padded
|
||||||
|
version of the source image. This canvas is then blended with the destination,
|
||||||
|
which cleanly handles all boundary conditions (e.g., source placed partially
|
||||||
|
or fully off-screen).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (ImageBatch): The background image tensor in (B, C, H, W) format.
|
||||||
|
source (ImageBatch): The foreground image tensor to composite, also (B, C, H, W).
|
||||||
|
x (int): The x-coordinate (from left) to place the top-left corner of the source.
|
||||||
|
y (int): The y-coordinate (from top) to place the top-left corner of the source.
|
||||||
|
mask (Optional[MaskBatch]): An optional luma mask tensor with the same batch size,
|
||||||
|
height, and width as the destination (B, H, W).
|
||||||
|
Values of 1.0 indicate using the source pixel, while
|
||||||
|
0.0 indicates using the destination pixel. If None,
|
||||||
|
the source is treated as fully opaque.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageBatch: The resulting composited image tensor.
|
||||||
|
"""
|
||||||
|
if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor):
|
||||||
|
raise TypeError("destination and source must be torch.Tensor")
|
||||||
|
if destination.dim() != 4 or source.dim() != 4:
|
||||||
|
raise ValueError("destination and source must be 4D tensors (B, C, H, W)")
|
||||||
|
|
||||||
source = source.to(destination.device)
|
source = source.to(destination.device)
|
||||||
|
|
||||||
if source.shape[0] != destination.shape[0]:
|
if source.shape[0] != destination.shape[0]:
|
||||||
|
if destination.shape[0] % source.shape[0] != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Destination batch size must be a multiple of source batch size for broadcasting."
|
||||||
|
)
|
||||||
source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
|
source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
|
||||||
|
|
||||||
x, y = int(x), int(y)
|
dest_b, dest_c, dest_h, dest_w = destination.shape
|
||||||
left, top = x, y
|
src_h, src_w = source.shape[2:]
|
||||||
right, bottom = left + source.shape[3], top + source.shape[2]
|
|
||||||
|
|
||||||
|
dest_y_start = max(0, y)
|
||||||
|
dest_y_end = min(dest_h, y + src_h)
|
||||||
|
dest_x_start = max(0, x)
|
||||||
|
dest_x_end = min(dest_w, x + src_w)
|
||||||
|
|
||||||
|
src_y_start = max(0, -y)
|
||||||
|
src_y_end = src_y_start + (dest_y_end - dest_y_start)
|
||||||
|
src_x_start = max(0, -x)
|
||||||
|
src_x_end = src_x_start + (dest_x_end - dest_x_start)
|
||||||
|
|
||||||
|
if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end:
|
||||||
|
return destination
|
||||||
|
padded_source = torch.zeros_like(destination)
|
||||||
|
padded_source[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = source[
|
||||||
|
:, :, src_y_start:src_y_end, src_x_start:src_x_end
|
||||||
|
]
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = torch.ones_like(source)
|
final_mask = torch.zeros(dest_b, 1, dest_h, dest_w, device=destination.device)
|
||||||
|
final_mask[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = 1.0
|
||||||
else:
|
else:
|
||||||
mask = mask.to(destination.device, copy=True)
|
if mask.dim() != 3 or mask.shape[0] != dest_b or mask.shape[1] != dest_h or mask.shape[2] != dest_w:
|
||||||
if mask.dim() == 2:
|
raise ValueError(
|
||||||
mask = mask.unsqueeze(0)
|
f"Provided mask shape {mask.shape} is invalid. "
|
||||||
if mask.dim() == 3:
|
f"Expected (batch, height, width): ({dest_b}, {dest_h}, {dest_w})."
|
||||||
mask = mask.unsqueeze(1)
|
)
|
||||||
if mask.shape[0] != source.shape[0]:
|
final_mask = mask.to(destination.device).unsqueeze(1)
|
||||||
mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
|
|
||||||
|
|
||||||
dest_left, dest_top = max(0, left), max(0, top)
|
blended_image = padded_source * final_mask + destination * (1.0 - final_mask)
|
||||||
dest_right, dest_bottom = min(destination.shape[3], right), min(destination.shape[2], bottom)
|
|
||||||
|
|
||||||
if dest_right <= dest_left or dest_bottom <= dest_top: return destination
|
return blended_image
|
||||||
|
|
||||||
src_left, src_top = dest_left - left, dest_top - top
|
|
||||||
src_right, src_bottom = dest_right - left, dest_bottom
|
|
||||||
|
|
||||||
destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
|
|
||||||
source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
|
|
||||||
mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right]
|
|
||||||
|
|
||||||
blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
|
|
||||||
destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
|
|
||||||
return destination
|
|
||||||
|
|
||||||
|
|
||||||
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
|
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user