diff --git a/comfy_extras/constants/resolutions.py b/comfy_extras/constants/resolutions.py
index 423c70642..542b9e338 100644
--- a/comfy_extras/constants/resolutions.py
+++ b/comfy_extras/constants/resolutions.py
@@ -89,3 +89,5 @@ WAN_VIDEO_14B_EXTENDED_RESOLUTIONS = [
     (704, 544),
     (544, 704)
 ]
+
+RESOLUTION_NAMES = ["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram", "Cosmos", "HunyuanVideo", "WAN 14b", "WAN 1.3b", "WAN 14b with extras"]
\ No newline at end of file
diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py
index a999d24e6..3865bd61b 100644
--- a/comfy_extras/nodes/nodes_images.py
+++ b/comfy_extras/nodes/nodes_images.py
@@ -16,7 +16,7 @@ from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
 from comfy_extras.constants.resolutions import SDXL_SD3_FLUX_RESOLUTIONS, LTVX_RESOLUTIONS, SD_RESOLUTIONS, \
     IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS, WAN_VIDEO_14B_RESOLUTIONS, \
-    WAN_VIDEO_1_3B_RESOLUTIONS, WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
+    WAN_VIDEO_1_3B_RESOLUTIONS, WAN_VIDEO_14B_EXTENDED_RESOLUTIONS, RESOLUTION_NAMES
 
 
 def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
@@ -274,7 +274,7 @@ class ImageResize:
             "required": {
                 "image": ("IMAGE",),
                 "resize_mode": (["cover", "contain", "auto"], {"default": "cover"}),
-                "resolutions": (["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram", "Cosmos", "HunyuanVideo", "WAN 14b", "WAN 1.3b", "WAN 14b with extras"], {"default": "SDXL/SD3/Flux"}),
+                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
                 "interpolation": (ImageScale.upscale_methods, {"default": "lanczos"}),
             },
             "optional": {
@@ -313,7 +313,6 @@ class ImageResize:
 
         h, w = img.shape[:2]
         current_aspect_ratio = w / h
-
         aspect_ratio_diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
         min_diff = min(aspect_ratio_diffs, key=lambda x: x[0])[0]
         close_enough_resolutions = [res for diff, res in aspect_ratio_diffs if diff <= min_diff + aspect_ratio_tolerance]
diff --git a/comfy_extras/nodes/nodes_inpainting.py b/comfy_extras/nodes/nodes_inpainting.py
new file mode 100644
index 000000000..5aba49af1
--- /dev/null
+++ b/comfy_extras/nodes/nodes_inpainting.py
@@ -0,0 +1,244 @@
+import torch
+import torch.nn.functional as F
+
+from comfy.component_model.tensor_types import MaskBatch
+from comfy_extras.constants.resolutions import RESOLUTION_NAMES
+from comfy_extras.nodes.nodes_images import ImageResize
+
+
+# Helper function to composite a source image onto a destination image
+def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
+    # Adapted from existing compositing code; expects destination and source in BCHW layout
+    source = source.to(destination.device)
+    if resize_source:
+        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
+
+    # Ensure source has the same batch size as destination
+    if source.shape[0] != destination.shape[0]:
+        source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
+
+    x = int(x)
+    y = int(y)
+
+    left, top = (x, y)
+    right, bottom = (left + source.shape[3], top + source.shape[2])
+
+    if mask is None:
+        # If no mask is provided, create a full-coverage mask
+        mask = torch.ones_like(source)
+    else:
+        # Ensure mask is on the correct device and is the correct size
+        mask = mask.to(destination.device, copy=True)
+        # Check if the mask is 2D (H, W) or 3D (B, H, W) and unsqueeze if necessary
+        if mask.dim() == 2:
+            mask = mask.unsqueeze(0)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(1)  # Add channel dimension
+        mask = torch.nn.functional.interpolate(mask, size=(source.shape[2], source.shape[3]), mode="bilinear")
+        if mask.shape[0] != source.shape[0]:
+            mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
+
+    # Define the bounds of the overlapping area
+    dest_left = max(0, left)
+    dest_top = max(0, top)
+    dest_right = min(destination.shape[3], right)
+    dest_bottom = min(destination.shape[2], bottom)
+
+    # If there is no overlap, return the original destination
+    if dest_right <= dest_left or dest_bottom <= dest_top:
+        return destination
+
+    # Calculate the source coordinates corresponding to the overlap
+    src_left = dest_left - left
+    src_top = dest_top - top
+    src_right = dest_right - left
+    src_bottom = dest_bottom - top
+
+    # Crop the relevant portions of the destination, source, and mask
+    destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
+    source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
+    mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]
+
+    inverse_mask_portion = 1.0 - mask_portion
+
+    # Perform the composition
+    blended_portion = (source_portion * mask_portion) + (destination_portion * inverse_mask_portion)
+
+    # Place the blended portion back into the destination
+    destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
+
+    return destination
+
+
+def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
+    """Parses a CSS-style margin string into (top, right, bottom, left)."""
+    parts = [int(p) for p in margin_str.strip().split()]
+    if len(parts) == 1:
+        return parts[0], parts[0], parts[0], parts[0]
+    if len(parts) == 2:
+        return parts[0], parts[1], parts[0], parts[1]
+    if len(parts) == 3:
+        return parts[0], parts[1], parts[2], parts[1]
+    if len(parts) == 4:
+        return parts[0], parts[1], parts[2], parts[3]
+    raise ValueError("Invalid margin format. Use 1 to 4 integer values.")
+
+
+class CropAndFitInpaintToDiffusionSize:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "mask": ("MASK",),
+                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
+                "margin": ("STRING", {"default": "64"}),
+                "overflow": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE", "MASK", "COMBO[INT]")
+    RETURN_NAMES = ("image", "mask", "composite_context")
+    FUNCTION = "crop_and_fit"
+    CATEGORY = "inpaint"
+
+    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool):
+        # 1. Find bounding box of the mask
+        if mask.max() <= 0:
+            raise ValueError("Mask is empty, cannot determine bounding box.")
+
+        # Find the coordinates of non-zero mask pixels
+        mask_coords = torch.nonzero(mask[0])  # Assuming single batch for mask
+        if mask_coords.numel() == 0:
+            raise ValueError("Mask is empty, cannot determine bounding box.")
+
+        y_min, x_min = mask_coords.min(dim=0).values
+        y_max, x_max = mask_coords.max(dim=0).values
+
+        # 2. Parse and apply margin
+        top_margin, right_margin, bottom_margin, left_margin = parse_margin(margin)
+
+        x_start = x_min.item() - left_margin
+        y_start = y_min.item() - top_margin
+        x_end = x_max.item() + 1 + right_margin
+        y_end = y_max.item() + 1 + bottom_margin
+
+        img_height, img_width = image.shape[1:3]
+
+        # Store the crop context (in original image coordinates) for the compositor node
+        context = {
+            "x": x_start,
+            "y": y_start,
+            "width": x_end - x_start,
+            "height": y_end - y_start
+        }
+
+        # 3. Handle overflow
+        padded_image = image
+        padded_mask = mask
+
+        pad_left = -min(0, x_start)
+        pad_top = -min(0, y_start)
+        pad_right = max(0, x_end - img_width)
+        pad_bottom = max(0, y_end - img_height)
+
+        if any([pad_left, pad_top, pad_right, pad_bottom]):
+            if not overflow:
+                # Crop the margin to fit within the image and update the context to match
+                x_start = max(0, x_start)
+                y_start = max(0, y_start)
+                x_end = min(img_width, x_end)
+                y_end = min(img_height, y_end)
+                context = {"x": x_start, "y": y_start, "width": x_end - x_start, "height": y_end - y_start}
+            else:
+                # Extend image and mask
+                padding = (pad_left, pad_right, pad_top, pad_bottom)
+                # Pad image with gray
+                padded_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
+                # Pad mask with zeros
+                padded_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
+
+                # Adjust coordinates for the new padded space
+                x_start += pad_left
+                y_start += pad_top
+                x_end += pad_left
+                y_end += pad_top
+
+        # 4. Crop image and mask
+        cropped_image = padded_image[:, y_start:y_end, x_start:x_end, :]
+        cropped_mask = padded_mask[:, y_start:y_end, x_start:x_end]
+
+        # 5. Resize to a supported resolution
+        resizer = ImageResize()
+        resized_image, = resizer.resize_image(cropped_image, "cover", resolutions, "lanczos")
+
+        # Resize mask similarly. Convert to image-like tensor for resizing.
+        cropped_mask_as_image = cropped_mask.unsqueeze(-1).repeat(1, 1, 1, 3)
+        resized_mask_as_image, = resizer.resize_image(cropped_mask_as_image, "cover", resolutions, "lanczos")
+        # Convert back to a mask (using the red channel)
+        resized_mask = resized_mask_as_image[:, :, :, 0]
+
+        # Pack context into a tuple of ints for output
+        # Format: (x, y, width, height)
+        composite_context = (context["x"], context["y"], context["width"], context["height"])
+
+        return (resized_image, resized_mask, composite_context)
+
+
+class CompositeCroppedAndFittedInpaintResult:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "source_image": ("IMAGE",),
+                "inpainted_image": ("IMAGE",),
+                "inpainted_mask": ("MASK",),
+                "composite_context": ("COMBO[INT]",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "composite_result"
+    CATEGORY = "inpaint"
+
+    def composite_result(self, source_image: torch.Tensor, inpainted_image: torch.Tensor, inpainted_mask: MaskBatch, composite_context: tuple[int, ...]):
+        # Unpack context
+        x, y, width, height = composite_context
+
+        # The inpainted image and mask are at a diffusion resolution. Resize them back to the original crop size.
+        target_size = (height, width)
+
+        # Resize inpainted image
+        inpainted_image_permuted = inpainted_image.movedim(-1, 1)
+        resized_inpainted_image = F.interpolate(inpainted_image_permuted, size=target_size, mode="bilinear", align_corners=False)
+
+        # Resize inpainted mask
+        # Add channel dim: (B, H, W) -> (B, 1, H, W)
+        inpainted_mask_unsqueezed = inpainted_mask.unsqueeze(1)
+        resized_inpainted_mask = F.interpolate(inpainted_mask_unsqueezed, size=target_size, mode="bilinear", align_corners=False)
+
+        # Prepare for compositing
+        destination_image = source_image.clone().movedim(-1, 1)
+
+        # Composite the resized inpainted image back onto the source image
+        final_image_permuted = composite(
+            destination=destination_image,
+            source=resized_inpainted_image,
+            x=x,
+            y=y,
+            mask=resized_inpainted_mask
+        )
+
+        final_image = final_image_permuted.movedim(1, -1)
+        return (final_image,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
+    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
+    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
+}
diff --git a/tests/unit/test_inpainting_utils.py b/tests/unit/test_inpainting_utils.py
new file mode 100644
index 000000000..41d836b5e
--- /dev/null
+++ b/tests/unit/test_inpainting_utils.py
@@ -0,0 +1,161 @@
+import pytest
+import torch
+import numpy as np
+
+# Node classes and helpers under test live in comfy_extras/nodes/nodes_inpainting.py
+from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult, parse_margin
+
+
+# Helper to create a circular mask
+def create_circle_mask(height, width, center_y, center_x, radius):
+    """Creates a float mask containing a filled circle."""
+    Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
+    distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
+    mask = (distance <= radius).float()
+    return mask.unsqueeze(0)  # Add batch dimension
+
+
+@pytest.fixture
+def sample_image() -> torch.Tensor:
+    """A 256x256 image with a vertical gradient."""
+    gradient = torch.linspace(0, 1, 256).view(1, -1, 1, 1)
+    image = gradient.expand(1, 256, 256, 3)  # (B, H, W, C)
+    return image
+
+
+@pytest.fixture
+def rect_mask() -> torch.Tensor:
+    """A rectangular mask in the center of a 256x256 image."""
+    mask = torch.zeros(1, 256, 256)
+    mask[:, 100:150, 80:180] = 1.0
+    return mask
+
+
+@pytest.fixture
+def circle_mask() -> torch.Tensor:
+    """A circular mask in a 256x256 image."""
+    return create_circle_mask(256, 256, center_y=128, center_x=128, radius=50)
+
+
+def test_parse_margin():
+    """Tests the margin parsing utility function."""
+    assert parse_margin("10") == (10, 10, 10, 10)
+    assert parse_margin(" 10 20 ") == (10, 20, 10, 20)
+    assert parse_margin("10 20 30") == (10, 20, 30, 20)
+    assert parse_margin("10 20 30 40") == (10, 20, 30, 40)
+    with pytest.raises(ValueError):
+        parse_margin("10 20 30 40 50")
+    with pytest.raises(ValueError):
+        parse_margin("not a number")
+
+
+def test_crop_and_fit_basic(sample_image, rect_mask):
+    """Tests the basic functionality of the cropping and fitting node."""
+    node = CropAndFitInpaintToDiffusionSize()
+
+    # Using SD1.5 resolutions for predictability in tests
+    img, msk, ctx = node.crop_and_fit(sample_image, rect_mask, resolutions="SD1.5", margin="20", overflow=False)
+
+    # Check output shapes
+    assert img.shape[0] == 1 and img.shape[3] == 3
+    assert msk.shape[0] == 1
+    # Check if resized to a valid SD1.5 resolution
+    assert (img.shape[2], img.shape[1]) in [(512, 512), (768, 512), (512, 768)]
+    assert img.shape[1:3] == msk.shape[1:3]
+
+    # Check context
+    # Original mask bounds: y(100, 149), x(80, 179)
+    # With margin 20: y(80, 169), x(60, 199)
+    # context is (x, y, width, height)
+    expected_x = 80 - 20
+    expected_y = 100 - 20
+    expected_width = (180 - 80) + 2 * 20
+    expected_height = (150 - 100) + 2 * 20
+
+    assert ctx == (expected_x, expected_y, expected_width, expected_height)
+
+
+def test_crop_and_fit_overflow(sample_image, rect_mask):
+    """Tests the overflow logic by placing the mask at an edge."""
+    node = CropAndFitInpaintToDiffusionSize()
+    edge_mask = torch.zeros_like(rect_mask)
+    edge_mask[:, :20, :50] = 1.0  # Mask at the top-left corner
+
+    # Test with overflow disabled (should clamp)
+    _, _, ctx_no_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=False)
+    assert ctx_no_overflow == (0, 0, 50 + 30, 20 + 30)
+
+    # Test with overflow enabled
+    img, msk, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
+    # Context should have negative coordinates
+    # Original bounds: y(0, 19), x(0, 49)
+    # Margin 30: y(-30, 49), x(-30, 79)
+    assert ctx_overflow == (-30, -30, (50 - 0) + 60, (20 - 0) + 60)
+
+    # Check that padded area is gray
+    # The original image was placed inside a larger gray canvas.
+    # We check a pixel that should be in the padded gray area of the *cropped* image.
+    # The crop starts at y=-30, x=-30 relative to original image.
+    # So, pixel (5,5) in the cropped image corresponds to (-25, -25) which is padding.
+    assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]))
+
+    # Check that original image content is still there
+    # Pixel (40, 40) in cropped image corresponds to (10, 10) in original image
+    assert torch.allclose(img[0, 40, 40, :], sample_image[0, 10, 10, :])
+
+
+def test_empty_mask_raises_error(sample_image):
+    """Tests that an empty mask correctly raises a ValueError."""
+    node = CropAndFitInpaintToDiffusionSize()
+    empty_mask = torch.zeros(1, 256, 256)
+    with pytest.raises(ValueError, match="Mask is empty"):
+        node.crop_and_fit(sample_image, empty_mask, "SD1.5", "10", False)
+
+
+@pytest.mark.parametrize("mask_fixture, margin, overflow", [
+    ("rect_mask", "16", False),
+    ("circle_mask", "32", False),
+    ("rect_mask", "64", True),  # overflow enabled for this case
+    ("circle_mask", "0", False),
+])
+def test_end_to_end_composition(request, sample_image, mask_fixture, margin, overflow):
+    """Performs a full round-trip test of both nodes."""
+    mask = request.getfixturevalue(mask_fixture)
+
+    # --- 1. Crop and Fit ---
+    crop_node = CropAndFitInpaintToDiffusionSize()
+    cropped_img, cropped_mask, context = crop_node.crop_and_fit(
+        sample_image, mask, "SD1.5", margin, overflow
+    )
+
+    # --- 2. Simulate Inpainting ---
+    # Create a solid blue image as the "inpainted" result
+    h, w = cropped_img.shape[1:3]
+    blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
+    inpainted_sim = blue_color.expand(1, h, w, 3)
+    # The inpainted_mask is the mask output from the first node
+    inpainted_mask = cropped_mask
+
+    # --- 3. Composite Result ---
+    composite_node = CompositeCroppedAndFittedInpaintResult()
+    final_image, = composite_node.composite_result(
+        source_image=sample_image,
+        inpainted_image=inpainted_sim,
+        inpainted_mask=inpainted_mask,
+        composite_context=context
+    )
+
+    # --- 4. Verify Result ---
+    assert final_image.shape == sample_image.shape
+
+    # Create a boolean version of the original mask for easy indexing
+    bool_mask = mask.squeeze(0).bool()  # H, W
+
+    # Area *inside* the mask should be blue
+    masked_area_in_final = final_image[0][bool_mask]
+    assert torch.allclose(masked_area_in_final, blue_color.squeeze(), atol=1e-2)
+
+    # Area *outside* the mask should be unchanged from the original
+    unmasked_area_in_final = final_image[0][~bool_mask]
+    unmasked_area_in_original = sample_image[0][~bool_mask]
+    assert torch.allclose(unmasked_area_in_final, unmasked_area_in_original, atol=1e-2)