ComfyUI/comfy_extras/nodes/nodes_inpainting.py

152 lines
8.5 KiB
Python

import torch
import torch.nn.functional as F
from comfy.component_model.tensor_types import MaskBatch
from comfy_extras.constants.resolutions import (
RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
WAN_VIDEO_14B_RESOLUTIONS, WAN_VIDEO_1_3B_RESOLUTIONS,
WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
)
def _repeat_to_batch(tensor, batch_size):
    """Tile ``tensor`` along dim 0 and trim it so it has exactly ``batch_size`` entries."""
    if tensor.shape[0] == batch_size:
        return tensor
    repeats = -(-batch_size // tensor.shape[0])  # ceiling division
    return tensor.repeat(repeats, *([1] * (tensor.dim() - 1)))[:batch_size]


def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
    """Alpha-blend ``source`` into ``destination`` at offset ``(x, y)``, in place.

    Args:
        destination: 4-D tensor indexed as (B, C, H, W); modified in place and returned.
        source: 4-D tensor indexed as (B, C, H, W) to paste. Moved to
            ``destination``'s device and tiled/trimmed to its batch size.
        x, y: top-left position of ``source`` in destination pixel coordinates.
            May be negative or push ``source`` past the far edges; the part
            falling outside ``destination`` is clipped.
        mask: optional blend weights shaped (H, W), (B, H, W) or (B, C, H, W).
            1 takes ``source``, 0 keeps ``destination``. Resized to the source's
            spatial size when it differs, and tiled to the source batch size.
            Defaults to all ones (a plain paste).
        multiplier: unused by this implementation; kept for call compatibility.
        resize_source: if True, ``source`` is first resized to ``destination``'s H/W.

    Returns:
        ``destination`` (the same tensor object, after blending).
    """
    source = source.to(destination.device)
    if resize_source:
        source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
    source = _repeat_to_batch(source, destination.shape[0])

    left, top = int(x), int(y)
    right, bottom = left + source.shape[3], top + source.shape[2]

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        if not mask.is_floating_point():
            # Boolean/integer masks cannot take part in the (1 - mask) blend below.
            mask = mask.float()
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)
        if mask.shape[2:] != source.shape[2:]:
            mask = F.interpolate(mask, size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = _repeat_to_batch(mask, source.shape[0])

    # Visible rectangle in destination coordinates.
    dest_left, dest_top = max(0, left), max(0, top)
    dest_right = min(destination.shape[3], right)
    dest_bottom = min(destination.shape[2], bottom)
    if dest_right <= dest_left or dest_bottom <= dest_top:
        return destination  # source lies entirely outside destination

    # The same rectangle expressed in source coordinates (both edges offset by the origin).
    src_left, src_top = dest_left - left, dest_top - top
    src_right, src_bottom = dest_right - left, dest_bottom - top

    destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
    source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
    mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]
    blended_portion = source_portion * mask_portion + destination_portion * (1.0 - mask_portion)
    destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
    return destination
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
    """Parse a CSS-style margin shorthand into ``(top, right, bottom, left)``.

    Accepts one to four integers separated by whitespace and/or commas,
    expanded like the CSS ``margin`` shorthand:

    * ``"a"``       -> ``(a, a, a, a)``
    * ``"a b"``     -> ``(a, b, a, b)``
    * ``"a b c"``   -> ``(a, b, c, b)``
    * ``"a b c d"`` -> ``(a, b, c, d)``

    Raises:
        ValueError: if the string is empty, has more than four values, or
            contains a token that is not an integer.
    """
    tokens = margin_str.replace(",", " ").split()
    try:
        parts = [int(token) for token in tokens]
    except ValueError as err:
        raise ValueError("Invalid margin format.") from err
    if len(parts) == 1:
        return parts[0], parts[0], parts[0], parts[0]
    if len(parts) == 2:
        return parts[0], parts[1], parts[0], parts[1]
    if len(parts) == 3:
        return parts[0], parts[1], parts[2], parts[1]
    if len(parts) == 4:
        return parts[0], parts[1], parts[2], parts[3]
    raise ValueError("Invalid margin format.")
class CropAndFitInpaintToDiffusionSize:
    """Crop the masked region of an image and fit it to a supported diffusion resolution.

    Outputs:
        image: the crop, scaled to cover the chosen target resolution and
            centre-cropped to exactly that size.
        mask: the matching mask, processed the same way and binarized at 0.5.
        composite_context: ``(x, y, width, height)`` - the rectangle of the
            source image (source pixel coordinates) that the output image
            covers. ``x``/``y`` may be negative when ``overflow`` padded the
            crop past the image border. Consumed by
            ``CompositeCroppedAndFittedInpaintResult``.
    """

    # ``resolutions`` input value -> list of (width, height) buckets.
    _RESOLUTION_MAP = {
        "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS,
        "SD1.5": SD_RESOLUTIONS,
        "LTXV": LTVX_RESOLUTIONS,
        "Ideogram": IDEOGRAM_RESOLUTIONS,
        "Cosmos": COSMOS_RESOLUTIONS,
        "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS,
        "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS,
        "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS,
        "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS,
    }

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
                "margin": ("STRING", {"default": "64"}),
                "overflow": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "COMBO[INT]")
    RETURN_NAMES = ("image", "mask", "composite_context")
    FUNCTION = "crop_and_fit"
    CATEGORY = "inpaint"

    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05):
        """Crop around ``mask`` (plus ``margin``) and fit the crop to a target resolution.

        Args:
            image: image batch indexed as (B, H, W, C).
            mask: mask batch (B, H, W) or a single (H, W) mask with the same
                H/W as ``image``. A smaller mask batch (e.g. 1) is tiled over
                the image batch. The crop covers the union of all masks.
            resolutions: key selecting the resolution bucket list; unknown keys
                fall back to the SD1.5 buckets.
            margin: CSS-style ``"top right bottom left"`` shorthand (see ``parse_margin``).
            overflow: if True, a margin reaching past the image border is
                filled with gray (0.5) padding instead of being clipped.
            aspect_ratio_tolerance: buckets whose aspect-ratio error is within
                this amount of the best bucket are candidates; the largest-area
                candidate is chosen.

        Returns:
            ``(image, mask, composite_context)`` as described on the class.

        Raises:
            ValueError: if the mask is empty, its size differs from the image,
                or the margin leaves an empty crop.
        """
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        if mask.max() <= 0:
            raise ValueError("Mask is empty.")
        img_h, img_w = image.shape[1:3]
        if tuple(mask.shape[-2:]) != (img_h, img_w):
            raise ValueError("Mask and image must have the same height and width.")

        # Bounding box (rows = y, columns = x) of the union of every mask in the batch.
        mask_coords = torch.nonzero(mask.amax(dim=0) > 0)
        if mask_coords.numel() == 0:
            raise ValueError("Mask is empty.")
        y_min, x_min = (int(v) for v in mask_coords.min(dim=0).values)
        y_max, x_max = (int(v) for v in mask_coords.max(dim=0).values)

        top_m, right_m, bottom_m, left_m = parse_margin(margin)
        x_start_init, y_start_init = x_min - left_m, y_min - top_m
        x_end_init, y_end_init = x_max + 1 + right_m, y_max + 1 + bottom_m

        pad_l, pad_t = max(0, -x_start_init), max(0, -y_start_init)
        pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h)
        if overflow and any((pad_l, pad_t, pad_r, pad_b)):
            # F.pad pads the last two dims in (left, right, top, bottom) order.
            padding = (pad_l, pad_r, pad_t, pad_b)
            pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
            pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
            x_start_crop, y_start_crop = x_start_init + pad_l, y_start_init + pad_t
            x_end_crop, y_end_crop = x_end_init + pad_l, y_end_init + pad_t
            # Keep the unclipped (possibly negative) origin so the composite step lines up.
            origin_x, origin_y = x_start_init, y_start_init
        else:
            pad_image, pad_mask = image, mask
            x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init)
            x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init)
            origin_x, origin_y = x_start_crop, y_start_crop

        if x_end_crop <= x_start_crop or y_end_crop <= y_start_crop:
            raise ValueError("Crop region is empty; check the margin.")

        cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :]
        cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop]
        # torch.cat below needs matching device, dtype and batch size.
        cropped_mask = cropped_mask.to(device=cropped_image.device, dtype=cropped_image.dtype)
        batch = cropped_image.shape[0]
        if cropped_mask.shape[0] != batch:
            repeats = -(-batch // cropped_mask.shape[0])  # ceiling division
            cropped_mask = cropped_mask.repeat(repeats, 1, 1)[:batch]

        supported_resolutions = self._RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
        h, w = cropped_image.shape[1:3]
        current_aspect_ratio = w / h
        diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
        min_diff = min(diff for diff, _ in diffs)
        close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance]
        target_res = max(close_res, key=lambda r: r[0] * r[1])
        target_w, target_h = int(target_res[0]), int(target_res[1])

        # Scale so the crop covers the target in both dimensions. max() guards
        # against float error truncating a dimension to one pixel below target.
        scale = max(target_w / w, target_h / h)
        new_w = max(target_w, int(w * scale))
        new_h = max(target_h, int(h * scale))

        n_img_ch = cropped_image.shape[-1]
        stacked_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1)
        upscaled = F.interpolate(stacked_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
        off_y, off_x = (new_h - target_h) // 2, (new_w - target_w) // 2
        fitted = upscaled[:, :, off_y:off_y + target_h, off_x:off_x + target_w].permute(0, 2, 3, 1)
        resized_image = fitted[..., :n_img_ch]
        # The mask was appended as the last channel, after all image channels.
        resized_mask = (fitted[..., n_img_ch] > 0.5).float()

        # The output covers only the centre window of the upscaled crop. Map
        # that window back to source pixels so compositing does not stretch it.
        ctx_x = origin_x + round(off_x * w / new_w)
        ctx_y = origin_y + round(off_y * h / new_h)
        ctx_w = max(1, round(target_w * w / new_w))
        ctx_h = max(1, round(target_h * h / new_h))
        return (resized_image, resized_mask, (ctx_x, ctx_y, ctx_w, ctx_h))
class CompositeCroppedAndFittedInpaintResult:
    """Paste an inpainted, resolution-fitted crop back into the full-size source image.

    ``composite_context`` is the ``(x, y, width, height)`` tuple emitted by
    ``CropAndFitInpaintToDiffusionSize``: the rectangle of the source image
    (source pixel coordinates, possibly partly outside the image) that the
    inpainted image corresponds to. The inpainted image is resized to
    ``(height, width)``, clipped to the part of that rectangle lying inside
    the source image, and blended in using the matching crop of
    ``source_mask``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source_image": ("IMAGE",),
                "source_mask": ("MASK",),
                "inpainted_image": ("IMAGE",),
                "composite_context": ("COMBO[INT]",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite_result"
    CATEGORY = "inpaint"

    def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]):
        """Blend ``inpainted_image`` into a copy of ``source_image``.

        Args:
            source_image: original image batch indexed as (B, H, W, C).
            source_mask: mask (B, H, W) or (H, W) at the source image's resolution;
                it weights the blend (1 = inpainted pixel, 0 = original pixel).
            inpainted_image: inpainted batch indexed as (B, h, w, C).
            composite_context: ``(x, y, width, height)`` placement rectangle.

        Returns:
            A one-tuple holding the composited (B, H, W, C) image; ``source_image``
            itself is not modified.
        """
        x, y, width, height = (int(v) for v in composite_context)
        img_h, img_w = source_image.shape[1:3]
        resized_inpainted = F.interpolate(
            inpainted_image.permute(0, 3, 1, 2), size=(height, width), mode="bilinear", align_corners=False
        )
        if source_mask.dim() == 2:
            source_mask = source_mask.unsqueeze(0)

        # Part of the placement rectangle that lies inside the source image.
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(img_w, x + width), min(img_h, y + height)

        destination = source_image.clone().permute(0, 3, 1, 2)
        if x1 > x0 and y1 > y0:
            # Clip the patch and its mask to the same in-image rectangle so they
            # stay aligned even when the context extends past the border.
            patch = resized_inpainted[:, :, y0 - y:y1 - y, x0 - x:x1 - x]
            patch_mask = source_mask[:, y0:y1, x0:x1]
            destination = composite(destination=destination, source=patch, x=x0, y=y0, mask=patch_mask)
        return (destination.permute(0, 2, 3, 1),)
# Node id -> implementing class for each node defined in this module.
NODE_CLASS_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
}

# Display names keyed by the same node ids.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
}