ComfyUI/comfy_extras/nodes/nodes_inpainting.py
from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
import comfy.utils
from jaxtyping import Float
from torch import Tensor
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
from ..constants.resolutions import RESOLUTION_MAP, SD_RESOLUTIONS, RESOLUTION_NAMES
class CompositeContext(NamedTuple):
    """Crop rectangle, in source-image pixel coordinates, used to paste an inpainted result back."""
    x: int
    y: int
    width: int
    height: int
def composite(
destination: Float[Tensor, "B C H W"],
source: Float[Tensor, "B C H W"],
x: int,
y: int,
mask: Optional[MaskBatch] = None,
) -> ImageBatch:
"""
Composites a source image onto a destination image at a given (x, y) coordinate
using an optional mask.
This simplified implementation first creates a destination-sized, zero-padded
version of the source image. This canvas is then blended with the destination,
which cleanly handles all boundary conditions (e.g., source placed partially
or fully off-screen).
Args:
        destination (Tensor): The background image tensor in (B, C, H, W) format.
        source (Tensor): The foreground image tensor to composite, also (B, C, H, W).
x (int): The x-coordinate (from left) to place the top-left corner of the source.
y (int): The y-coordinate (from top) to place the top-left corner of the source.
mask (Optional[MaskBatch]): An optional luma mask tensor with the same batch size,
height, and width as the destination (B, H, W).
Values of 1.0 indicate using the source pixel, while
0.0 indicates using the destination pixel. If None,
the source is treated as fully opaque.
Returns:
ImageBatch: The resulting composited image tensor.
"""
if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor):
raise TypeError("destination and source must be torch.Tensor")
if destination.dim() != 4 or source.dim() != 4:
raise ValueError("destination and source must be 4D tensors (B, C, H, W)")
source = source.to(destination.device)
if source.shape[0] != destination.shape[0]:
if destination.shape[0] % source.shape[0] != 0:
raise ValueError(
"Destination batch size must be a multiple of source batch size for broadcasting."
)
source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
dest_b, dest_c, dest_h, dest_w = destination.shape
src_h, src_w = source.shape[2:]
dest_y_start = max(0, y)
dest_y_end = min(dest_h, y + src_h)
dest_x_start = max(0, x)
dest_x_end = min(dest_w, x + src_w)
src_y_start = max(0, -y)
src_y_end = src_y_start + (dest_y_end - dest_y_start)
src_x_start = max(0, -x)
src_x_end = src_x_start + (dest_x_end - dest_x_start)
if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end:
return destination
padded_source = torch.zeros_like(destination)
padded_source[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = source[
:, :, src_y_start:src_y_end, src_x_start:src_x_end
]
if mask is None:
final_mask = torch.zeros(dest_b, 1, dest_h, dest_w, device=destination.device)
final_mask[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = 1.0
else:
if mask.dim() != 3 or mask.shape[0] != dest_b or mask.shape[1] != dest_h or mask.shape[2] != dest_w:
raise ValueError(
f"Provided mask shape {mask.shape} is invalid. "
f"Expected (batch, height, width): ({dest_b}, {dest_h}, {dest_w})."
)
final_mask = mask.to(destination.device).unsqueeze(1)
blended_image = padded_source * final_mask + destination * (1.0 - final_mask)
return blended_image
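
# A minimal illustrative sketch (not part of the node API) of `composite`
# semantics. Shapes, coordinates, and values below are assumptions chosen for
# demonstration, including the partially off-canvas placement case.
def _composite_example() -> None:
    dest = torch.zeros(1, 3, 8, 8)   # black 8x8 canvas
    src = torch.ones(1, 3, 4, 4)     # white 4x4 patch
    # Without a mask the patch is fully opaque wherever it lands, even when it
    # hangs 2 px past the right edge of the canvas.
    out = composite(dest, src, x=6, y=2)
    assert out[0, 0, 2, 6] == 1.0 and out[0, 0, 2, 5] == 0.0
    # With a destination-sized mask, only mask == 1.0 pixels take the source.
    mask = torch.zeros(1, 8, 8)
    mask[:, 2:4, 6:8] = 1.0          # reveal only the top rows of the overlap
    out = composite(dest, src, x=6, y=2, mask=mask)
    assert out[0, 0, 5, 6] == 0.0    # unmasked overlap row keeps the destination
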
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
    """Parses a CSS-style margin shorthand string into (top, right, bottom, left)."""
    parts = [int(p) for p in margin_str.strip().split()]
    match len(parts):
        case 1:
            return parts[0], parts[0], parts[0], parts[0]
        case 2:
            return parts[0], parts[1], parts[0], parts[1]
        case 3:
            return parts[0], parts[1], parts[2], parts[1]
        case 4:
            return parts[0], parts[1], parts[2], parts[3]
        case _:
            raise ValueError("Invalid margin format: expected 1 to 4 integers.")
class CropAndFitInpaintToDiffusionSize(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",), "mask": ("MASK",),
"resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
"margin": ("STRING", {"default": "64"}),
}
}
RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
RETURN_NAMES = ("image", "mask", "composite_context")
FUNCTION = "crop_and_fit"
CATEGORY = "inpaint"
    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
        """Crops the masked region (expanded by the margin), snaps the crop to the
        nearest supported aspect ratio, and resizes it to that diffusion resolution.
        Returns the fitted image, the fitted mask, and a CompositeContext for pasting
        the inpainted result back into the source image."""
if mask.max() == 0.0:
raise ValueError("Mask is empty (all black).")
mask_coords = torch.nonzero(mask)
if mask_coords.numel() == 0:
raise ValueError("Mask is empty (all black).")
y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
y_min, x_min = y_coords.min().item(), x_coords.min().item()
y_max, x_max = y_coords.max().item(), x_coords.max().item()
top_m, right_m, bottom_m, left_m = parse_margin(margin)
x_start_expanded, y_start_expanded = x_min - left_m, y_min - top_m
x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m
img_h, img_w = image.shape[1:3]
clamped_x_start = max(0, x_start_expanded)
clamped_y_start = max(0, y_start_expanded)
clamped_x_end = min(img_w, x_end_expanded)
clamped_y_end = min(img_h, y_end_expanded)
initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
if initial_w <= 0 or initial_h <= 0:
raise ValueError("Cropped area has zero dimension.")
supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
target_res = min(diffs, key=lambda x: x[0])[1]
target_ar = target_res[0] / target_res[1]
current_ar = initial_w / initial_h
if current_ar > target_ar:
cover_w, cover_h = float(initial_w), float(initial_w) / target_ar
else:
cover_h, cover_w = float(initial_h), float(initial_h) * target_ar
if cover_w > img_w or cover_h > img_h:
final_x, final_y, final_w, final_h = 0, 0, img_w, img_h
full_img_ar = img_w / img_h
diffs_full = [(abs(res[0] / res[1] - full_img_ar), res) for res in supported_resolutions]
target_res = min(diffs_full, key=lambda x: x[0])[1]
else:
center_x = clamped_x_start + initial_w / 2
center_y = clamped_y_start + initial_h / 2
final_x, final_y = center_x - cover_w / 2, center_y - cover_h / 2
final_w, final_h = cover_w, cover_h
if final_x < 0:
final_x = 0
if final_y < 0:
final_y = 0
if final_x + final_w > img_w:
final_x = img_w - final_w
if final_y + final_h > img_h:
final_y = img_h - final_h
final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)
cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]
resized_image = F.interpolate(cropped_image.permute(0, 3, 1, 2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1)
composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
return (resized_image, resized_mask, composite_context)
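
# Illustrative sketch of the resolution-selection step above: the bucket whose
# aspect ratio is closest to the crop's is chosen. The candidate list here is
# an assumption for demonstration; the node reads real lists from RESOLUTION_MAP.
def _nearest_resolution_example() -> None:
    candidates = [(512, 512), (768, 512), (512, 768)]  # (width, height) buckets
    w, h = 600, 400                                    # crop is landscape (AR = 1.5)
    best = min(candidates, key=lambda res: abs(res[0] / res[1] - w / h))
    assert best == (768, 512)                          # the landscape bucket wins
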
class CompositeCroppedAndFittedInpaintResult(CustomNode):
@classmethod
    def INPUT_TYPES(cls):
return {
"required": {
"source_image": ("IMAGE",),
"source_mask": ("MASK",),
"inpainted_image": ("IMAGE",),
"composite_context": ("COMPOSITE_CONTEXT",),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "composite_result"
CATEGORY = "inpaint"
    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext) -> tuple[ImageBatch]:
        """Resizes the inpainted image back to the original crop dimensions and
        composites it onto the source image through the source mask."""
context_x, context_y, context_w, context_h = composite_context
resized_inpainted = F.interpolate(
inpainted_image.permute(0, 3, 1, 2),
size=(context_h, context_w),
mode="bilinear", align_corners=False
)
final_image = composite(
destination=source_image.clone().permute(0, 3, 1, 2),
source=resized_inpainted,
x=context_x,
y=context_y,
mask=source_mask
)
        return (final_image.permute(0, 2, 3, 1),)
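
# Illustrative round trip, as a sketch: crop the masked region, stand in for a
# real inpainting sampler, then paste the result back using the saved
# CompositeContext. Assumes the first entry of RESOLUTION_NAMES maps to a
# resolution list that contains a roughly square bucket.
def _inpaint_round_trip_example() -> None:
    image = torch.rand(1, 512, 512, 3)   # NHWC, as ComfyUI IMAGE tensors are
    mask = torch.zeros(1, 512, 512)
    mask[:, 100:200, 150:250] = 1.0      # region to inpaint
    crop_node = CropAndFitInpaintToDiffusionSize()
    fitted_img, fitted_mask, ctx = crop_node.crop_and_fit(
        image, mask, RESOLUTION_NAMES[0], margin="64")
    inpainted = fitted_img               # stand-in for a real sampler output
    comp_node = CompositeCroppedAndFittedInpaintResult()
    (result,) = comp_node.composite_result(image, mask, inpainted, ctx)
    assert result.shape == image.shape
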
class ImageAndMaskResizeNode:
"""
Sherlocked from https://github.com/CY-CHENYUE/ComfyUI-InpaintEasy
MIT License
Copyright (c) 2024 CYCHENYUE
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
    DESCRIPTION = "Resizes the image and mask simultaneously (adapted from InpaintEasy)."
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center", "top_left", "top_right", "bottom_left", "bottom_right", "random"]
def __init__(self):
self.type = "ImageMaskResize"
self.output_node = True
@classmethod
    def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"mask": ("MASK",),
"width": ("INT", {
"default": 512,
"min": 64,
"max": 8192,
"step": 8
}),
"height": ("INT", {
"default": 512,
"min": 64,
"max": 8192,
"step": 8
}),
"resize_method": (s.upscale_methods, {"default": "lanczos"}),
"crop": (s.crop_methods, {"default": "disabled"}),
"mask_blur_radius": ("INT", {
"default": 10,
"min": 0,
"max": 64,
"step": 1
}),
}
}
RETURN_TYPES = ("IMAGE", "MASK",)
RETURN_NAMES = ("image", "mask",)
FUNCTION = "resize_image_and_mask"
CATEGORY = "inpaint"
def resize_image_and_mask(self, image, mask, width, height, resize_method="lanczos", crop="disabled", mask_blur_radius=0):
        # Handle the case where both width and height are 0
if width == 0 and height == 0:
return (image, mask)
        # Handle the image
samples = image.movedim(-1, 1) # NHWC -> NCHW
if width == 0:
width = max(1, round(samples.shape[3] * height / samples.shape[2]))
elif height == 0:
height = max(1, round(samples.shape[2] * width / samples.shape[3]))
        # Scale and crop directly with torch.nn.functional
if crop != "disabled":
old_width = samples.shape[3]
old_height = samples.shape[2]
            # Compute the scale factor
scale = max(width / old_width, height / old_height)
scaled_width = int(old_width * scale)
scaled_height = int(old_height * scale)
            # Resize with common_upscale
samples = comfy.utils.common_upscale(samples, scaled_width, scaled_height, resize_method, crop="disabled")
            # The mask always uses bilinear interpolation
mask = F.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(scaled_height, scaled_width), mode='bilinear', align_corners=True)
            # Compute the crop position
crop_x = 0
crop_y = 0
if crop == "center":
crop_x = (scaled_width - width) // 2
crop_y = (scaled_height - height) // 2
elif crop == "top_left":
crop_x = 0
crop_y = 0
elif crop == "top_right":
crop_x = scaled_width - width
crop_y = 0
elif crop == "bottom_left":
crop_x = 0
crop_y = scaled_height - height
elif crop == "bottom_right":
crop_x = scaled_width - width
crop_y = scaled_height - height
elif crop == "random":
crop_x = torch.randint(0, max(1, scaled_width - width), (1,)).item()
crop_y = torch.randint(0, max(1, scaled_height - height), (1,)).item()
            # Perform the crop
samples = samples[:, :, crop_y:crop_y + height, crop_x:crop_x + width]
mask = mask[:, :, crop_y:crop_y + height, crop_x:crop_x + width]
else:
            # Resize directly with common_upscale
samples = comfy.utils.common_upscale(samples, width, height, resize_method, crop="disabled")
mask = F.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(height, width), mode='bilinear', align_corners=True)
image_resized = samples.movedim(1, -1) # NCHW -> NHWC
mask_resized = mask.squeeze(1) # NCHW -> NHW
        # Apply Gaussian blur to the mask before returning
if mask_blur_radius > 0:
            # Build the 1D Gaussian kernel (sigma = radius / 3)
kernel_size = mask_blur_radius * 2 + 1
x = torch.arange(kernel_size, dtype=torch.float32, device=mask_resized.device)
x = x - (kernel_size - 1) / 2
gaussian = torch.exp(-(x ** 2) / (2 * (mask_blur_radius / 3) ** 2))
gaussian = gaussian / gaussian.sum()
            # Expand the kernel to 2D via an outer product
gaussian_2d = gaussian.view(1, -1) * gaussian.view(-1, 1)
gaussian_2d = gaussian_2d.view(1, 1, kernel_size, kernel_size)
            # Apply the Gaussian blur
mask_for_blur = mask_resized.unsqueeze(1) # Add channel dimension
            # Pad borders in reflect mode to avoid edge artifacts
padding = kernel_size // 2
mask_padded = F.pad(mask_for_blur, (padding, padding, padding, padding), mode='reflect')
mask_resized = F.conv2d(mask_padded, gaussian_2d.to(mask_resized.device), padding=0).squeeze(1)
            # Clamp values to the [0, 1] range
mask_resized = torch.clamp(mask_resized, 0, 1)
return (image_resized, mask_resized)
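
# A sketch of the separable-Gaussian construction used for mask feathering
# above, factored out for clarity: kernel length = 2 * radius + 1 and
# sigma = radius / 3, so the kernel effectively reaches zero at its edges.
def _gaussian_kernel_example(radius: int = 10) -> torch.Tensor:
    size = radius * 2 + 1
    x = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-(x ** 2) / (2 * (radius / 3) ** 2))
    g = g / g.sum()
    kernel_2d = g.view(-1, 1) * g.view(1, -1)        # outer product of the 1D kernel
    assert abs(kernel_2d.sum().item() - 1.0) < 1e-5  # normalized: blur preserves mass
    return kernel_2d.view(1, 1, size, size)          # shaped for F.conv2d
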
NODE_CLASS_MAPPINGS = {
"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
"CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
"ImageAndMaskResizeNode": ImageAndMaskResizeNode
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
"CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
"ImageAndMaskResizeNode": "Image and Mask Resize"
}