From d4c9d5c74886488d12edf263e5769c48750b491f Mon Sep 17 00:00:00 2001
From: Benjamin Berman
Date: Sat, 7 Jun 2025 10:19:05 -0700
Subject: [PATCH] inpainting nodes

---
 comfy_extras/nodes/nodes_images.py     |   6 +-
 comfy_extras/nodes/nodes_inpainting.py | 159 +++++++++++++------------
 tests/unit/test_inpainting_utils.py    |  98 ++++++++++-----
 3 files changed, 154 insertions(+), 109 deletions(-)

diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py
index 3865bd61b..2129e4cb1 100644
--- a/comfy_extras/nodes/nodes_images.py
+++ b/comfy_extras/nodes/nodes_images.py
@@ -10,7 +10,7 @@ from PIL.PngImagePlugin import PngInfo
 from comfy import utils
 from comfy.cli_args import args
 from comfy.cmd import folder_paths
-from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
+from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch, RGBAImageBatch
 from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
@@ -286,7 +286,7 @@ class ImageResize:
     FUNCTION = "resize_image"
     CATEGORY = "image/transform"
 
-    def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
+    def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
         if resolutions == "SDXL/SD3/Flux":
             supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS
         elif resolutions == "LTXV":
@@ -307,7 +307,7 @@ class ImageResize:
             supported_resolutions = SD_RESOLUTIONS
         return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation, aspect_ratio_tolerance=aspect_ratio_tolerance)
 
-    def resize_image_with_supported_resolutions(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
+    def resize_image_with_supported_resolutions(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
         resized_images = []
         for img in image:
             h, w = img.shape[:2]
diff --git a/comfy_extras/nodes/nodes_inpainting.py b/comfy_extras/nodes/nodes_inpainting.py
index 3abc5f198..6720f5963 100644
--- a/comfy_extras/nodes/nodes_inpainting.py
+++ b/comfy_extras/nodes/nodes_inpainting.py
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
+from typing import NamedTuple, Optional
 
-from comfy.component_model.tensor_types import MaskBatch
+from comfy.component_model.tensor_types import MaskBatch, ImageBatch
+from comfy.nodes.package_typing import CustomNode
 from comfy_extras.constants.resolutions import (
     RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
     IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
@@ -9,11 +11,15 @@ from comfy_extras.constants.resolutions import (
     WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
 )
 
+class CompositeContext(NamedTuple):
+    x: int
+    y: int
+    width: int
+    height: int
 
-def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
+def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None):
"""A robust function to composite a source tensor onto a destination tensor.""" source = source.to(destination.device) - if resize_source: - source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") if source.shape[0] != destination.shape[0]: source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1) @@ -40,13 +46,14 @@ def composite(destination, source, x, y, mask=None, multiplier=1, resize_source= destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right] source_portion = source[:, :, src_top:src_bottom, src_left:src_right] - mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right] + + # The mask must be cropped to the region of interest on the destination. + mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right] blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion)) destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion return destination - def parse_margin(margin_str: str) -> tuple[int, int, int, int]: parts = [int(p) for p in margin_str.strip().split()] if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0] @@ -55,97 +62,93 @@ def parse_margin(margin_str: str) -> tuple[int, int, int, int]: if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3] raise ValueError("Invalid margin format.") - -class CropAndFitInpaintToDiffusionSize: +class CropAndFitInpaintToDiffusionSize(CustomNode): @classmethod def INPUT_TYPES(cls): - return {"required": {"image": ("IMAGE",), "mask": ("MASK",), "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}), "margin": ("STRING", {"default": "64"}), "overflow": ("BOOLEAN", {"default": True}), }} + return {"required": { + "image": ("IMAGE",), "mask": ("MASK",), + "resolutions": (RESOLUTION_NAMES, {"default": "SD1.5"}), + "margin": ("STRING", {"default": "64"}), + }} - RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = ("IMAGE", "MASK", "COMBO[INT]"), ("image", "mask", "composite_context"), "crop_and_fit", "inpaint" + RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT") + RETURN_NAMES = ("image", "mask", "composite_context") + FUNCTION = "crop_and_fit" + CATEGORY = "inpaint" - def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05): + def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, aspect_ratio_tolerance=0.05): if mask.max() <= 0: raise ValueError("Mask is empty.") - mask_coords = torch.nonzero(mask[0]); + mask_coords = torch.nonzero(mask) if mask_coords.numel() == 0: raise ValueError("Mask is empty.") - y_min, x_min = mask_coords.min(dim=0).values; - y_max, x_max = mask_coords.max(dim=0).values + + y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2] + y_min, x_min = y_coords.min().item(), x_coords.min().item() + y_max, x_max = y_coords.max().item(), x_coords.max().item() + top_m, right_m, bottom_m, left_m = parse_margin(margin) - x_start_init, y_start_init = x_min.item() - left_m, y_min.item() - top_m - x_end_init, y_end_init = x_max.item() + 1 + right_m, y_max.item() + 1 + bottom_m + x_start_expanded, y_start_expanded = x_min - left_m, y_min - top_m + x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m + img_h, img_w = image.shape[1:3] - pad_image, pad_mask = image, mask - x_start_crop, y_start_crop = x_start_init, y_start_init - x_end_crop, y_end_crop = x_end_init, y_end_init - pad_l, pad_t = -min(0, 
x_start_init), -min(0, y_start_init) - pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h) - if any([pad_l, pad_t, pad_r, pad_b]) and overflow: - padding = (pad_l, pad_r, pad_t, pad_b) - pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1) - pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1) - x_start_crop += pad_l; - y_start_crop += pad_t; - x_end_crop += pad_l; - y_end_crop += pad_t - else: - x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init) - x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init) - composite_x, composite_y = (x_start_init if overflow else x_start_crop), (y_start_init if overflow else y_start_crop) - cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :] - cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop] - context = {"x": composite_x, "y": composite_y, "width": cropped_image.shape[2], "height": cropped_image.shape[1]} + clamped_x_start, clamped_y_start = max(0, x_start_expanded), max(0, y_start_expanded) + clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded) - rgba_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1) - res_map = {"SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS} + initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start + if initial_w <= 0 or initial_h <= 0: raise ValueError("Cropped area has zero dimension.") + + res_map = { "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS } supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS) - h, w = cropped_image.shape[1:3] - current_aspect_ratio = w / h - diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions] - min_diff = min(diffs, key=lambda x: x[0])[0] - close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance] - target_res = max(close_res, key=lambda r: r[0] * r[1]) - scale = max(target_res[0] / w, target_res[1] / h) - new_w, new_h = int(w * scale), int(h * scale) - upscaled_rgba = F.interpolate(rgba_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False) - y1, x1 = (new_h - target_res[1]) // 2, (new_w - target_res[0]) // 2 - final_rgba_bchw = upscaled_rgba[:, :, y1:y1 + target_res[1], x1:x1 + target_res[0]] - final_rgba_bhwc = final_rgba_bchw.permute(0, 2, 3, 1) - resized_image = final_rgba_bhwc[..., :3] - resized_mask = (final_rgba_bhwc[..., 3] > 0.5).float() - return (resized_image, resized_mask, (context["x"], context["y"], context["width"], context["height"])) + diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions] + target_res = min(diffs, key=lambda x: x[0])[1] + target_ar = target_res[0] / target_res[1] + current_ar = initial_w / initial_h + final_x, final_y = float(clamped_x_start), float(clamped_y_start) + final_w, final_h = float(initial_w), 
float(initial_h) -class CompositeCroppedAndFittedInpaintResult: + if current_ar > target_ar: + final_w = initial_h * target_ar + final_x += (initial_w - final_w) / 2 + else: + final_h = initial_w / target_ar + final_y += (initial_h - final_h) / 2 + + final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h) + + cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w] + cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w] + + resized_image = F.interpolate(cropped_image.permute(0,3,1,2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0,2,3,1) + resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1) + + composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h) + return (resized_image, resized_mask, composite_context) + +class CompositeCroppedAndFittedInpaintResult(CustomNode): @classmethod def INPUT_TYPES(s): - return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMBO[INT]",), }} + return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMPOSITE_CONTEXT",),}} RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint" - def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]): - x, y, width, height = composite_context - target_size = (height, width) + def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext): + context_x, context_y, context_w, context_h = composite_context - resized_inpainted_image = F.interpolate(inpainted_image.permute(0, 3, 1, 2), size=target_size, mode="bilinear", align_corners=False) + resized_inpainted = F.interpolate( + inpainted_image.permute(0, 3, 1, 2), + size=(context_h, context_w), + mode="bilinear", align_corners=False + ) - # FIX: The logic for cropping the original mask was flawed. - # This simpler approach directly crops the relevant section of the original source_mask. - # It correctly handles negative coordinates from the overflow case. - crop_x_start = max(0, x) - crop_y_start = max(0, y) - crop_x_end = min(source_image.shape[2], x + width) - crop_y_end = min(source_image.shape[1], y + height) - - # The mask for compositing is a direct, high-resolution crop of the source mask. - final_compositing_mask = source_mask[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end] - - destination_image = source_image.clone().permute(0, 3, 1, 2) - - # We now pass our perfectly cropped high-res mask to the composite function. - # Note that the `composite` function handles placing this at the correct sub-region. 
-        final_image_permuted = composite(destination=destination_image, source=resized_inpainted_image, x=x, y=y, mask=final_compositing_mask)
-
-        return (final_image_permuted.permute(0, 2, 3, 1),)
+        final_image = composite(
+            destination=source_image.clone().permute(0, 3, 1, 2),
+            source=resized_inpainted,
+            x=context_x,
+            y=context_y,
+            mask=source_mask
+        )
+        return (final_image.permute(0, 2, 3, 1),)
 
 NODE_CLASS_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize, "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult}
-NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}
+NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}
\ No newline at end of file
diff --git a/tests/unit/test_inpainting_utils.py b/tests/unit/test_inpainting_utils.py
index ba7a03b3f..931854332 100644
--- a/tests/unit/test_inpainting_utils.py
+++ b/tests/unit/test_inpainting_utils.py
@@ -1,61 +1,67 @@
 import pytest
 import torch
-import numpy as np
 
-# Assuming the node definitions are in a file named 'inpaint_nodes.py'
-from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult, parse_margin
+from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult
+
 
 def create_circle_mask(height, width, center_y, center_x, radius):
     Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
-    distance = torch.sqrt((Y - center_y)**2 + (X - center_x)**2)
+    distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
     return (distance <= radius).float().unsqueeze(0)
+
 
 @pytest.fixture
 def sample_image() -> torch.Tensor:
     gradient = torch.linspace(0, 1, 256).view(1, -1, 1, 1)
     return gradient.expand(1, 256, 256, 3)
 
+
+@pytest.fixture
+def image_1024() -> torch.Tensor:
+    gradient = torch.linspace(0, 1, 1024).view(1, -1, 1, 1)
+    return gradient.expand(1, 1024, 1024, 3)
+
+
 @pytest.fixture
 def rect_mask() -> torch.Tensor:
     mask = torch.zeros(1, 256, 256)
     mask[:, 100:150, 80:180] = 1.0
     return mask
 
+
 @pytest.fixture
 def circle_mask() -> torch.Tensor:
     return create_circle_mask(256, 256, center_y=128, center_x=128, radius=50)
 
-def test_crop_and_fit_overflow(sample_image, rect_mask):
-    """Tests the overflow logic by placing the mask at an edge."""
-    node = CropAndFitInpaintToDiffusionSize()
-    edge_mask = torch.zeros_like(rect_mask)
-    edge_mask[:, :20, :50] = 1.0
-    _, _, ctx_no_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=False)
-    assert ctx_no_overflow == (0, 0, 80, 50)
-    img, _, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
-    assert ctx_overflow == (-30, -30, 110, 80)
-    assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]), atol=1e-3)
 
-@pytest.mark.parametrize("mask_fixture, margin, overflow", [
-    ("rect_mask", "16", False),
-    ("circle_mask", "32", False),
-    ("rect_mask", "64", True),
-    ("circle_mask", "0", False),
+def test_crop_and_fit_edge_clamp(sample_image):
+    node = CropAndFitInpaintToDiffusionSize()
+    edge_mask = torch.zeros(1, 256, 256)
+    edge_mask[:, :20, :50] = 1.0
+
+    _, _, context = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30")
edge_mask, "SD1.5", "30") + + target_aspect_ratio = 1.0 # For SD1.5, the only valid resolution is 512x512 + actual_aspect_ratio = context.width / context.height + assert abs(actual_aspect_ratio - target_aspect_ratio) < 1e-4 + + +@pytest.mark.parametrize("mask_fixture, margin", [ + ("rect_mask", "16"), + ("circle_mask", "32"), + ("circle_mask", "0"), ]) -def test_end_to_end_composition(request, sample_image, mask_fixture, margin, overflow): - """Performs a full round-trip test of both nodes.""" +def test_end_to_end_composition(request, sample_image, mask_fixture, margin): mask = request.getfixturevalue(mask_fixture) crop_node = CropAndFitInpaintToDiffusionSize() composite_node = CompositeCroppedAndFittedInpaintResult() - # The resized mask from the first node is not needed for compositing. - cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin, overflow) + cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin) h, w = cropped_img.shape[1:3] blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3) inpainted_sim = blue_color.expand(1, h, w, 3) - # FIX: Pass the original, high-resolution mask as `source_mask`. final_image, = composite_node.composite_result( source_image=sample_image, source_mask=mask, @@ -66,5 +72,41 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin, ove assert final_image.shape == sample_image.shape bool_mask = mask.squeeze(0).bool() + assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2) - assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2) \ No newline at end of file + assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask]) + + +def test_wide_ideogram_composite(image_1024): + """Tests the wide margin scenario. The node logic correctly chooses 1536x512.""" + source_image = image_1024 + mask = torch.zeros(1, 1024, 1024) + mask[:, 900:932, 950:982] = 1.0 + + crop_node = CropAndFitInpaintToDiffusionSize() + composite_node = CompositeCroppedAndFittedInpaintResult() + + margin = "64 64 64 400" + + cropped_img, _, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin) + assert cropped_img.shape[1:3] == (512, 1536) + + green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3) + inpainted_sim = green_color.expand(1, 512, 1536, 3) + + final_image, = composite_node.composite_result( + source_image=source_image, + source_mask=mask, + inpainted_image=inpainted_sim, + composite_context=context + ) + + assert final_image.shape == source_image.shape + + bool_mask = mask.squeeze(0).bool() + + final_pixels = final_image[0][bool_mask] + assert torch.all(final_pixels[:, 1] > final_pixels[:, 0]) + assert torch.all(final_pixels[:, 1] > final_pixels[:, 2]) + + assert torch.allclose(final_image[0, 916, 940, :], source_image[0, 916, 940, :])