from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
import comfy.utils

from jaxtyping import Float
from torch import Tensor

from comfy.component_model.tensor_types import MaskBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
from ..constants.resolutions import RESOLUTION_MAP, SD_RESOLUTIONS, RESOLUTION_NAMES

class CompositeContext(NamedTuple):
    """Region of the original image a crop was taken from, as (x, y, width, height)."""
    x: int
    y: int
    width: int
    height: int

def composite(
    destination: Float[Tensor, "B C H W"],
    source: Float[Tensor, "B C H W"],
    x: int,
    y: int,
    mask: Optional[MaskBatch] = None,
) -> ImageBatch:
    """
    Composites a source image onto a destination image at a given (x, y) coordinate
    using an optional mask.

    This simplified implementation first creates a destination-sized, zero-padded
    version of the source image. This canvas is then blended with the destination,
    which cleanly handles all boundary conditions (e.g., source placed partially
    or fully off-screen).

    Args:
        destination (ImageBatch): The background image tensor in (B, C, H, W) format.
        source (ImageBatch): The foreground image tensor to composite, also (B, C, H, W).
        x (int): The x-coordinate (from left) to place the top-left corner of the source.
        y (int): The y-coordinate (from top) to place the top-left corner of the source.
        mask (Optional[MaskBatch]): An optional luma mask tensor with the same batch size,
                                    height, and width as the destination (B, H, W).
                                    Values of 1.0 select the source pixel; 0.0 selects
                                    the destination pixel. If None, the source is
                                    treated as fully opaque.

    Returns:
        ImageBatch: The resulting composited image tensor.
    """
    if not isinstance(destination, torch.Tensor) or not isinstance(source, torch.Tensor):
        raise TypeError("destination and source must be torch.Tensor")
    if destination.dim() != 4 or source.dim() != 4:
        raise ValueError("destination and source must be 4D tensors (B, C, H, W)")

    source = source.to(destination.device)

    # Broadcast the source batch to the destination batch when possible.
    if source.shape[0] != destination.shape[0]:
        if destination.shape[0] % source.shape[0] != 0:
            raise ValueError(
                "Destination batch size must be a multiple of source batch size for broadcasting."
            )
        source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)

    dest_b, dest_c, dest_h, dest_w = destination.shape
    src_h, src_w = source.shape[2:]

    # Overlapping region in destination coordinates.
    dest_y_start = max(0, y)
    dest_y_end = min(dest_h, y + src_h)
    dest_x_start = max(0, x)
    dest_x_end = min(dest_w, x + src_w)

    # Corresponding region in source coordinates.
    src_y_start = max(0, -y)
    src_y_end = src_y_start + (dest_y_end - dest_y_start)
    src_x_start = max(0, -x)
    src_x_end = src_x_start + (dest_x_end - dest_x_start)

    # Source is entirely off-screen: nothing to composite.
    if dest_y_start >= dest_y_end or dest_x_start >= dest_x_end:
        return destination

    padded_source = torch.zeros_like(destination)
    padded_source[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = source[
        :, :, src_y_start:src_y_end, src_x_start:src_x_end
    ]

    if mask is None:
        # Fully opaque paste: the mask is 1.0 exactly where the source landed.
        # Match the destination dtype so blending does not silently upcast.
        final_mask = torch.zeros(dest_b, 1, dest_h, dest_w, device=destination.device, dtype=destination.dtype)
        final_mask[:, :, dest_y_start:dest_y_end, dest_x_start:dest_x_end] = 1.0
    else:
        if mask.dim() != 3 or mask.shape[0] != dest_b or mask.shape[1] != dest_h or mask.shape[2] != dest_w:
            raise ValueError(
                f"Provided mask shape {mask.shape} is invalid. "
                f"Expected (batch, height, width): ({dest_b}, {dest_h}, {dest_w})."
            )
        final_mask = mask.to(destination.device).unsqueeze(1)

    blended_image = padded_source * final_mask + destination * (1.0 - final_mask)

    return blended_image

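# Minimal usage sketch for composite() (illustrative only; the shapes and
# values below are made up for the example):
#
#     dest = torch.zeros(1, 3, 4, 4)        # 4x4 black canvas
#     src = torch.ones(1, 3, 2, 2)          # 2x2 white patch
#     out = composite(dest, src, x=1, y=1)  # paste at (1, 1), fully opaque
#     # out[0, :, 1:3, 1:3] is now 1.0; everything else stays 0.0.
#     # A patch placed at x=3 would be clipped to the single overlapping column.
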
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
    """Parses a CSS-style margin string of 1 to 4 integers into (top, right, bottom, left)."""
    parts = [int(p) for p in margin_str.strip().split()]
    match len(parts):
        case 1:
            return parts[0], parts[0], parts[0], parts[0]
        case 2:
            return parts[0], parts[1], parts[0], parts[1]
        case 3:
            return parts[0], parts[1], parts[2], parts[1]
        case 4:
            return parts[0], parts[1], parts[2], parts[3]
        case _:
            raise ValueError("Invalid margin format. Expected 1 to 4 whitespace-separated integers.")

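# parse_margin follows CSS shorthand ordering, e.g. (values are illustrative):
#
#     parse_margin("64")           # (64, 64, 64, 64)  all sides
#     parse_margin("32 64")        # (32, 64, 32, 64)  vertical, horizontal
#     parse_margin("16 32 48 64")  # (16, 32, 48, 64)  top, right, bottom, left
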
class CropAndFitInpaintToDiffusionSize(CustomNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
                "margin": ("STRING", {"default": "64"}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
    RETURN_NAMES = ("image", "mask", "composite_context")
    FUNCTION = "crop_and_fit"
    CATEGORY = "inpaint"

    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
        mask_coords = torch.nonzero(mask)
        if mask_coords.numel() == 0:
            raise ValueError("Mask is empty (all black).")

        # Bounding box of the masked region; nonzero() on a (B, H, W) mask
        # yields (batch, y, x) index columns.
        y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
        y_min, x_min = y_coords.min().item(), x_coords.min().item()
        y_max, x_max = y_coords.max().item(), x_coords.max().item()

        # Expand the bounding box by the requested margins.
        top_m, right_m, bottom_m, left_m = parse_margin(margin)
        x_start_expanded, y_start_expanded = x_min - left_m, y_min - top_m
        x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m

        img_h, img_w = image.shape[1:3]

        clamped_x_start = max(0, x_start_expanded)
        clamped_y_start = max(0, y_start_expanded)
        clamped_x_end = min(img_w, x_end_expanded)
        clamped_y_end = min(img_h, y_end_expanded)

        initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
        if initial_w <= 0 or initial_h <= 0:
            raise ValueError("Cropped area has zero dimension.")

        # Pick the supported resolution whose aspect ratio is closest to the crop's.
        supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
        diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
        target_res = min(diffs, key=lambda x: x[0])[1]
        target_ar = target_res[0] / target_res[1]

        # Grow the crop so it covers the mask region at the target aspect ratio.
        current_ar = initial_w / initial_h
        if current_ar > target_ar:
            cover_w, cover_h = float(initial_w), float(initial_w) / target_ar
        else:
            cover_h, cover_w = float(initial_h), float(initial_h) * target_ar

        if cover_w > img_w or cover_h > img_h:
            # The covering crop does not fit: fall back to the whole image and
            # re-pick the resolution nearest to the full image's aspect ratio.
            final_x, final_y, final_w, final_h = 0, 0, img_w, img_h
            full_img_ar = img_w / img_h
            diffs_full = [(abs(res[0] / res[1] - full_img_ar), res) for res in supported_resolutions]
            target_res = min(diffs_full, key=lambda x: x[0])[1]
        else:
            # Center the covering crop on the masked region.
            center_x = clamped_x_start + initial_w / 2
            center_y = clamped_y_start + initial_h / 2
            final_x, final_y = center_x - cover_w / 2, center_y - cover_h / 2
            final_w, final_h = cover_w, cover_h

        # Shift the crop back inside the image bounds if it overhangs an edge.
        if final_x < 0:
            final_x = 0
        if final_y < 0:
            final_y = 0
        if final_x + final_w > img_w:
            final_x = img_w - final_w
        if final_y + final_h > img_h:
            final_y = img_h - final_h

        final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)

        cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
        cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]

        # target_res is (width, height); F.interpolate expects size=(H, W).
        resized_image = F.interpolate(cropped_image.permute(0, 3, 1, 2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
        resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1)

        composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
        return (resized_image, resized_mask, composite_context)

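# Sketch of the nearest-aspect-ratio selection used above (assumes the
# resolution lists hold (width, height) tuples, e.g. (512, 512) and (768, 512)):
#
#     crop_ar = 700 / 500  # 1.4, e.g. a 700x500 crop
#     target = min(SD_RESOLUTIONS, key=lambda res: abs(res[0] / res[1] - crop_ar))
#     # With just the two tuples above, target would be (768, 512): its ratio
#     # 1.5 is closer to 1.4 than 1.0 is.
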
class CompositeCroppedAndFittedInpaintResult(CustomNode):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "source_image": ("IMAGE",),
                "source_mask": ("MASK",),
                "inpainted_image": ("IMAGE",),
                "composite_context": ("COMPOSITE_CONTEXT",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite_result"
    CATEGORY = "inpaint"

    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext) -> tuple[ImageBatch]:
        context_x, context_y, context_w, context_h = composite_context

        # Scale the inpainted result back down to the size of the original crop.
        resized_inpainted = F.interpolate(
            inpainted_image.permute(0, 3, 1, 2),  # NHWC -> NCHW
            size=(context_h, context_w),
            mode="bilinear", align_corners=False
        )

        # Paste it back where the crop was taken, constrained by the source mask.
        final_image = composite(
            destination=source_image.clone().permute(0, 3, 1, 2),
            source=resized_inpainted,
            x=context_x,
            y=context_y,
            mask=source_mask
        )

        return final_image.permute(0, 2, 3, 1),  # NCHW -> NHWC

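# Round-trip sketch of the two nodes above (illustrative; the image/mask
# tensors and the inpainting step in the middle are stand-ins):
#
#     crop_node = CropAndFitInpaintToDiffusionSize()
#     fitted_image, fitted_mask, ctx = crop_node.crop_and_fit(image, mask, RESOLUTION_NAMES[0], "64")
#     # ... run inpainting on (fitted_image, fitted_mask) to get `inpainted` ...
#     paste_node = CompositeCroppedAndFittedInpaintResult()
#     (result,) = paste_node.composite_result(image, mask, inpainted, ctx)
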
class ImageAndMaskResizeNode:
    """
    Sherlocked from https://github.com/CY-CHENYUE/ComfyUI-InpaintEasy

    MIT License

    Copyright (c) 2024 CYCHENYUE

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    """
    DESCRIPTION = "Resize the image and mask simultaneously (from InpaintEasy)"
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
    crop_methods = ["disabled", "center", "top_left", "top_right", "bottom_left", "bottom_right"]

    def __init__(self):
        self.type = "ImageMaskResize"
        self.output_node = True

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "width": ("INT", {
                    "default": 512,
                    "min": 64,
                    "max": 8192,
                    "step": 8
                }),
                "height": ("INT", {
                    "default": 512,
                    "min": 64,
                    "max": 8192,
                    "step": 8
                }),
                "resize_method": (s.upscale_methods, {"default": "lanczos"}),
                "crop": (s.crop_methods, {"default": "disabled"}),
                "mask_blur_radius": ("INT", {
                    "default": 10,
                    "min": 0,
                    "max": 64,
                    "step": 1
                }),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK",)
    RETURN_NAMES = ("image", "mask",)
    FUNCTION = "resize_image_and_mask"

    CATEGORY = "inpaint"

    def resize_image_and_mask(self, image, mask, width, height, resize_method="lanczos", crop="disabled", mask_blur_radius=0):
        # Handle the case where both width and height are 0
        if width == 0 and height == 0:
            return (image, mask)

        samples = image.movedim(-1, 1)  # NHWC -> NCHW
        # A single zero dimension is inferred from the image's aspect ratio
        if width == 0:
            width = max(1, round(samples.shape[3] * height / samples.shape[2]))
        elif height == 0:
            height = max(1, round(samples.shape[2] * width / samples.shape[3]))

        # Scale and crop directly with torch.nn.functional
        if crop != "disabled":
            old_width = samples.shape[3]
            old_height = samples.shape[2]

            # Scale so the image covers the target size in both dimensions
            scale = max(width / old_width, height / old_height)
            scaled_width = int(old_width * scale)
            scaled_height = int(old_height * scale)

            # Scale with common_upscale
            samples = comfy.utils.common_upscale(samples, scaled_width, scaled_height, resize_method, crop="disabled")

            # The mask always uses bilinear interpolation
            mask = F.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(scaled_height, scaled_width), mode='bilinear', align_corners=True)

            # Compute the crop position
            crop_x = 0
            crop_y = 0

            if crop == "center":
                crop_x = (scaled_width - width) // 2
                crop_y = (scaled_height - height) // 2
            elif crop == "top_left":
                crop_x = 0
                crop_y = 0
            elif crop == "top_right":
                crop_x = scaled_width - width
                crop_y = 0
            elif crop == "bottom_left":
                crop_x = 0
                crop_y = scaled_height - height
            elif crop == "bottom_right":
                crop_x = scaled_width - width
                crop_y = scaled_height - height
            elif crop == "random":
                # Note: "random" is not listed in crop_methods, so this branch is
                # only reachable when the method is called directly.
                crop_x = torch.randint(0, max(1, scaled_width - width), (1,)).item()
                crop_y = torch.randint(0, max(1, scaled_height - height), (1,)).item()

            # Perform the crop
            samples = samples[:, :, crop_y:crop_y + height, crop_x:crop_x + width]
            mask = mask[:, :, crop_y:crop_y + height, crop_x:crop_x + width]
        else:
            # Resize directly with common_upscale
            samples = comfy.utils.common_upscale(samples, width, height, resize_method, crop="disabled")
            mask = F.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(height, width), mode='bilinear', align_corners=True)

        image_resized = samples.movedim(1, -1)  # NCHW -> NHWC
        mask_resized = mask.squeeze(1)  # NCHW -> NHW

        # Feather the mask with a Gaussian blur before returning
        if mask_blur_radius > 0:
            # Build a 1D Gaussian kernel with sigma = radius / 3
            kernel_size = mask_blur_radius * 2 + 1
            x = torch.arange(kernel_size, dtype=torch.float32, device=mask_resized.device)
            x = x - (kernel_size - 1) / 2
            gaussian = torch.exp(-(x ** 2) / (2 * (mask_blur_radius / 3) ** 2))
            gaussian = gaussian / gaussian.sum()

            # Expand the kernel to 2D via the outer product
            gaussian_2d = gaussian.view(1, -1) * gaussian.view(-1, 1)
            gaussian_2d = gaussian_2d.view(1, 1, kernel_size, kernel_size)

            # Apply the Gaussian blur
            mask_for_blur = mask_resized.unsqueeze(1)  # Add channel dimension
            # Pad the borders in reflect mode to avoid edge artifacts
            padding = kernel_size // 2
            mask_padded = F.pad(mask_for_blur, (padding, padding, padding, padding), mode='reflect')
            mask_resized = F.conv2d(mask_padded, gaussian_2d.to(mask_resized.device), padding=0).squeeze(1)

            # Clamp values to the 0-1 range
            mask_resized = torch.clamp(mask_resized, 0, 1)

        return (image_resized, mask_resized)

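# Sketch of the blur construction above (illustrative values; radius 2 gives a
# 5-tap kernel with sigma = 2 / 3):
#
#     g = torch.exp(-(torch.arange(5.0) - 2) ** 2 / (2 * (2 / 3) ** 2))
#     g = g / g.sum()                      # normalized 1D Gaussian
#     k2d = g.view(1, -1) * g.view(-1, 1)  # separable outer product, sums to ~1.0
#     # Convolving with k2d feathers the mask edges without shifting them.
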
NODE_CLASS_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
    "ImageAndMaskResizeNode": ImageAndMaskResizeNode
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
    "ImageAndMaskResizeNode": "Image and Mask Resize"
}