# ComfyUI/comfy_extras/nodes/nodes_inpainting.py
from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from comfy.component_model.tensor_types import MaskBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
from ..constants.resolutions import RESOLUTION_MAP, SD_RESOLUTIONS, RESOLUTION_NAMES


class CompositeContext(NamedTuple):
    """The crop rectangle, in source-image pixel coordinates, that a fitted
    inpaint region was taken from. Used later to paste the result back."""
    x: int
    y: int
    width: int
    height: int


def composite(
    destination: Float[Tensor, "B C H W"],
    source: Float[Tensor, "B C H W"],
    x: int,
    y: int,
    mask: Optional[MaskBatch] = None,
) -> ImageBatch:
    """
    Composites a source image onto a destination image at a given (x, y)
    coordinate using an optional mask.

    This simplified implementation first creates a destination-sized,
    zero-padded copy of the source image. That canvas is then blended with the
    destination, which cleanly handles all boundary conditions (e.g., a source
    placed partially or fully off-screen).

    Args:
        destination: The background image tensor in (B, C, H, W) format.
        source: The foreground image tensor to composite, also (B, C, H, W).
        x: The x-coordinate (from left) of the source's top-left corner.
        y: The y-coordinate (from top) of the source's top-left corner.
        mask: An optional luma mask tensor with the same batch size, height,
            and width as the destination, in (B, H, W) format. Values of 1.0
            select the source pixel, 0.0 the destination pixel. If None, the
            source is treated as fully opaque.

    Returns:
        The composited image tensor in (B, C, H, W) format.
    """
    if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor):
        raise TypeError("destination and source must be torch.Tensor")
    if destination.dim() != 4 or source.dim() != 4:
        raise ValueError("destination and source must be 4D tensors (B, C, H, W)")
    source = source.to(destination.device)
    if source.shape[0] != destination.shape[0]:
        if destination.shape[0] % source.shape[0] != 0:
            raise ValueError(
                "Destination batch size must be a multiple of source batch size for broadcasting."
            )
        source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)

    dest_b, dest_c, dest_h, dest_w = destination.shape
    src_h, src_w = source.shape[2:]

    # Clip the placement rectangle to the destination canvas, and compute the
    # matching sub-rectangle of the source for placements that start off-canvas.
    dest_y_start = max(0, y)
    dest_y_end = min(dest_h, y + src_h)
    dest_x_start = max(0, x)
    dest_x_end = min(dest_w, x + src_w)
    src_y_start = max(0, -y)
    src_y_end = src_y_start + (dest_y_end - dest_y_start)
    src_x_start = max(0, -x)
    src_x_end = src_x_start + (dest_x_end - dest_x_start)

    # The source lies entirely outside the destination: nothing to composite.
    if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end:
        return destination

    # Paste the visible part of the source onto a zero-filled, destination-sized canvas.
    padded_source = torch.zeros_like(destination)
    padded_source[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = source[
        :, :, src_y_start:src_y_end, src_x_start:src_x_end
    ]

    if mask is None:
        # No mask: the source is fully opaque within its placement rectangle.
        final_mask = torch.zeros(dest_b, 1, dest_h, dest_w, device=destination.device)
        final_mask[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = 1.0
    else:
        if mask.dim() != 3 or mask.shape[0] != dest_b or mask.shape[1] != dest_h or mask.shape[2] != dest_w:
            raise ValueError(
                f"Provided mask shape {mask.shape} is invalid. "
                f"Expected (batch, height, width): ({dest_b}, {dest_h}, {dest_w})."
            )
        final_mask = mask.to(destination.device).unsqueeze(1)

    # Standard lerp: mask == 1.0 takes the padded source, mask == 0.0 keeps the destination.
    blended_image = padded_source * final_mask + destination * (1.0 - final_mask)
    return blended_image
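
# A minimal sketch of composite()'s boundary handling (illustrative values,
# not taken from a real workflow):
#
#   dest = torch.zeros(1, 3, 8, 8)
#   src = torch.ones(1, 3, 4, 4)
#   out = composite(dest, src, x=6, y=6)
#   # Only the 2x2 overlap at the bottom-right corner receives source pixels;
#   # the off-canvas portion of the source is silently cropped away.
#   assert out[0, :, 6:, 6:].eq(1.0).all() and out.sum() == 2 * 2 * 3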


def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
    """Parses a CSS-style margin shorthand of 1 to 4 whitespace-separated
    integers into a (top, right, bottom, left) tuple."""
    parts = [int(p) for p in margin_str.strip().split()]
    match len(parts):
        case 1:
            return parts[0], parts[0], parts[0], parts[0]
        case 2:
            return parts[0], parts[1], parts[0], parts[1]
        case 3:
            return parts[0], parts[1], parts[2], parts[1]
        case 4:
            return parts[0], parts[1], parts[2], parts[3]
        case _:
            raise ValueError("Invalid margin format; expected 1 to 4 integers.")
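
# Examples of the shorthand accepted above (hypothetical values):
#
#   parse_margin("64")          -> (64, 64, 64, 64)   # all four sides
#   parse_margin("32 64")       -> (32, 64, 32, 64)   # vertical, horizontal
#   parse_margin("8 16 24")     -> (8, 16, 24, 16)    # top, horizontal, bottom
#   parse_margin("8 16 24 32")  -> (8, 16, 24, 32)    # top, right, bottom, left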


class CropAndFitInpaintToDiffusionSize(CustomNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
                "margin": ("STRING", {"default": "64"}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
    RETURN_NAMES = ("image", "mask", "composite_context")
    FUNCTION = "crop_and_fit"
    CATEGORY = "inpaint"

    def crop_and_fit(self, image: ImageBatch, mask: MaskBatch, resolutions: str, margin: str):
        # Find the bounding box of the masked (nonzero) region.
        mask_coords = torch.nonzero(mask)
        if mask_coords.numel() == 0:
            raise ValueError("Mask is empty (all black).")
        y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
        y_min, x_min = y_coords.min().item(), x_coords.min().item()
        y_max, x_max = y_coords.max().item(), x_coords.max().item()

        # Expand the bounding box by the requested margins, then clamp to the image.
        top_m, right_m, bottom_m, left_m = parse_margin(margin)
        x_start_expanded, y_start_expanded = x_min - left_m, y_min - top_m
        x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m
        img_h, img_w = image.shape[1:3]
        clamped_x_start = max(0, x_start_expanded)
        clamped_y_start = max(0, y_start_expanded)
        clamped_x_end = min(img_w, x_end_expanded)
        clamped_y_end = min(img_h, y_end_expanded)
        initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
        if initial_w <= 0 or initial_h <= 0:
            raise ValueError("Cropped area has zero dimension.")

        # Pick the supported resolution whose aspect ratio is closest to the crop's.
        supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
        diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
        target_res = min(diffs, key=lambda d: d[0])[1]

        # Grow the crop along one axis ("cover") so it exactly matches the
        # target aspect ratio while still containing the expanded bounding box.
        target_ar = target_res[0] / target_res[1]
        current_ar = initial_w / initial_h
        if current_ar > target_ar:
            cover_w, cover_h = float(initial_w), float(initial_w) / target_ar
        else:
            cover_h, cover_w = float(initial_h), float(initial_h) * target_ar

        if cover_w > img_w or cover_h > img_h:
            # The cover crop cannot fit inside the image: fall back to the whole
            # image and re-pick the resolution for its aspect ratio.
            final_x, final_y, final_w, final_h = 0, 0, img_w, img_h
            full_img_ar = img_w / img_h
            diffs_full = [(abs(res[0] / res[1] - full_img_ar), res) for res in supported_resolutions]
            target_res = min(diffs_full, key=lambda d: d[0])[1]
        else:
            # Center the cover crop on the expanded bounding box, then shift it
            # back inside the image if it overhangs an edge.
            center_x = clamped_x_start + initial_w / 2
            center_y = clamped_y_start + initial_h / 2
            final_x, final_y = center_x - cover_w / 2, center_y - cover_h / 2
            final_w, final_h = cover_w, cover_h
            if final_x < 0:
                final_x = 0
            if final_y < 0:
                final_y = 0
            if final_x + final_w > img_w:
                final_x = img_w - final_w
            if final_y + final_h > img_h:
                final_y = img_h - final_h

        final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)
        cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
        cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]

        # Resolution tuples are (width, height); F.interpolate takes (height, width).
        resized_image = F.interpolate(
            cropped_image.permute(0, 3, 1, 2),
            size=(target_res[1], target_res[0]),
            mode="bilinear",
            align_corners=False,
        ).permute(0, 2, 3, 1)
        resized_mask = F.interpolate(
            cropped_mask.unsqueeze(1),
            size=(target_res[1], target_res[0]),
            mode="nearest",
        ).squeeze(1)
        composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
        return (resized_image, resized_mask, composite_context)
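
# Worked example of the "cover" geometry above (assumed numbers, for
# illustration only): a mask bounding box expanded to 200x100 (AR 2.0) with a
# closest supported resolution of 1024x1024 (AR 1.0) keeps its width, so
# cover_w = 200 and cover_h = 200 / 1.0 = 200. The 200x200 crop is centered on
# the box, nudged back inside the image if it overhangs, and finally resized
# to 1024x1024 for diffusion.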


class CompositeCroppedAndFittedInpaintResult(CustomNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source_image": ("IMAGE",),
                "source_mask": ("MASK",),
                "inpainted_image": ("IMAGE",),
                "composite_context": ("COMPOSITE_CONTEXT",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite_result"
    CATEGORY = "inpaint"

    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext) -> tuple[ImageBatch]:
        context_x, context_y, context_w, context_h = composite_context
        # Scale the inpainted result back to the size of the original crop...
        resized_inpainted = F.interpolate(
            inpainted_image.permute(0, 3, 1, 2),
            size=(context_h, context_w),
            mode="bilinear",
            align_corners=False,
        )
        # ...then paste it over the source, restricted to the original mask.
        final_image = composite(
            destination=source_image.clone().permute(0, 3, 1, 2),
            source=resized_inpainted,
            x=context_x,
            y=context_y,
            mask=source_mask,
        )
        # composite() works in (B, C, H, W); convert back to ComfyUI's (B, H, W, C).
        return (final_image.permute(0, 2, 3, 1),)
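
# A hedged sketch of the intended round trip (plain calls, shown only to
# illustrate data flow; "inpaint(...)" is a hypothetical stand-in for the
# sampling subgraph, and the resolution-list name is assumed):
#
#   fitted_img, fitted_mask, ctx = CropAndFitInpaintToDiffusionSize().crop_and_fit(
#       image, mask, resolutions=RESOLUTION_NAMES[0], margin="64")
#   inpainted = inpaint(fitted_img, fitted_mask)
#   (result,) = CompositeCroppedAndFittedInpaintResult().composite_result(
#       image, mask, inpainted, ctx)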


NODE_CLASS_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
}