From 41dc6ec7fa79f183a218a75d8584201f38769dd0 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 24 Jun 2025 12:27:06 -0700
Subject: [PATCH] fix inpainting alignment issue

---
 comfy_extras/nodes/nodes_inpainting.py |  46 ++++++---
 tests/unit/test_inpainting_utils.py    | 133 +++++++++++++++++++++++--
 2 files changed, 154 insertions(+), 25 deletions(-)
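
Notes (below the ---, so they stay out of the commit message): the old code
shrank the crop rectangle to the target aspect ratio, which could push masked
pixels out of the crop and misalign the composite. The new code instead grows
a covering rectangle around the mask-plus-margin region, shifts it back inside
the image bounds, and falls back to the full image (re-selecting the closest
resolution) when the cover cannot fit.

A minimal standalone sketch of that fitting logic, for review only. The
resolution entries and the top/right/bottom/left margin order are assumptions
inferred from the test expectations, not the node's real tables or parser:

def fit_cover_rect(mask_box, margins, img_w, img_h, resolutions):
    # mask_box: inclusive (x_min, y_min, x_max, y_max) of the nonzero mask
    # margins: (top, right, bottom, left), assumed CSS order
    x_min, y_min, x_max, y_max = mask_box
    top, right, bottom, left = margins
    x0, y0 = max(0, x_min - left), max(0, y_min - top)
    x1, y1 = min(img_w, x_max + 1 + right), min(img_h, y_max + 1 + bottom)
    w, h = x1 - x0, y1 - y0

    # pick the supported (width, height) whose aspect ratio is closest
    tw, th = min(resolutions, key=lambda r: abs(r[0] / r[1] - w / h))
    target_ar = tw / th

    # grow (never shrink) the region to the target aspect ratio
    if w / h > target_ar:
        cover_w, cover_h = float(w), float(w) / target_ar
    else:
        cover_w, cover_h = float(h) * target_ar, float(h)

    if cover_w > img_w or cover_h > img_h:
        return 0, 0, img_w, img_h  # cover cannot fit: use the whole image

    # center the cover on the region, then shift it back inside the image
    x = min(max(0.0, x0 + w / 2 - cover_w / 2), img_w - cover_w)
    y = min(max(0.0, y0 + h / 2 - cover_h / 2), img_h - cover_h)
    return int(x), int(y), int(cover_w), int(cover_h)

# Reproduces the "wide_ideogram_right_edge" expectation below: a 32x32 mask
# at (y=900, x=950) in a 1024x1024 image with margin "64 64 64 400" gives a
# 480x160 crop shifted left to (544, 836). (1536, 512) is an assumed entry
# in the Ideogram table.
assert fit_cover_rect((950, 900, 981, 931), (64, 64, 64, 400), 1024, 1024,
                      [(1536, 512), (1024, 1024), (512, 1536)]) == (544, 836, 480, 160)
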
diff --git a/comfy_extras/nodes/nodes_inpainting.py b/comfy_extras/nodes/nodes_inpainting.py
index 30c2a63ab..4964bbb33 100644
--- a/comfy_extras/nodes/nodes_inpainting.py
+++ b/comfy_extras/nodes/nodes_inpainting.py
@@ -1,7 +1,6 @@
-from typing import NamedTuple, Optional
-
 import torch
 import torch.nn.functional as F
+from typing import NamedTuple, Optional
 
 from comfy.component_model.tensor_types import MaskBatch, ImageBatch
 from comfy.nodes.package_typing import CustomNode
@@ -84,11 +83,12 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
     CATEGORY = "inpaint"
 
     def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
-        if mask.max() <= 0:
-            raise ValueError("Mask is empty.")
+        if mask.max() == 0.0:
+            raise ValueError("Mask is empty (all black).")
+
         mask_coords = torch.nonzero(mask)
         if mask_coords.numel() == 0:
-            raise ValueError("Mask is empty.")
+            raise ValueError("Mask is empty (all black).")
 
         y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
         y_min, x_min = y_coords.min().item(), x_coords.min().item()
@@ -99,8 +99,11 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
         x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m
 
         img_h, img_w = image.shape[1:3]
-        clamped_x_start, clamped_y_start = max(0, x_start_expanded), max(0, y_start_expanded)
-        clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)
+
+        clamped_x_start = max(0, x_start_expanded)
+        clamped_y_start = max(0, y_start_expanded)
+        clamped_x_end = min(img_w, x_end_expanded)
+        clamped_y_end = min(img_h, y_end_expanded)
 
         initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
         if initial_w <= 0 or initial_h <= 0:
@@ -112,15 +115,30 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
         target_ar = target_res[0] / target_res[1]
         current_ar = initial_w / initial_h
 
-        final_x, final_y = float(clamped_x_start), float(clamped_y_start)
-        final_w, final_h = float(initial_w), float(initial_h)
-
         if current_ar > target_ar:
-            final_w = initial_h * target_ar
-            final_x += (initial_w - final_w) / 2
+            cover_w, cover_h = float(initial_w), float(initial_w) / target_ar
         else:
-            final_h = initial_w / target_ar
-            final_y += (initial_h - final_h) / 2
+            cover_h, cover_w = float(initial_h), float(initial_h) * target_ar
+
+        if cover_w > img_w or cover_h > img_h:
+            final_x, final_y, final_w, final_h = 0, 0, img_w, img_h
+            full_img_ar = img_w / img_h
+            diffs_full = [(abs(res[0] / res[1] - full_img_ar), res) for res in supported_resolutions]
+            target_res = min(diffs_full, key=lambda x: x[0])[1]
+        else:
+            center_x = clamped_x_start + initial_w / 2
+            center_y = clamped_y_start + initial_h / 2
+            final_x, final_y = center_x - cover_w / 2, center_y - cover_h / 2
+            final_w, final_h = cover_w, cover_h
+
+            if final_x < 0:
+                final_x = 0
+            if final_y < 0:
+                final_y = 0
+            if final_x + final_w > img_w:
+                final_x = img_w - final_w
+            if final_y + final_h > img_h:
+                final_y = img_h - final_h
 
         final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)
 
diff --git a/tests/unit/test_inpainting_utils.py b/tests/unit/test_inpainting_utils.py
index 931854332..e7f0736d4 100644
--- a/tests/unit/test_inpainting_utils.py
+++ b/tests/unit/test_inpainting_utils.py
@@ -1,13 +1,65 @@
 import pytest
 import torch
 
-from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult
+from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, \
+    CompositeCroppedAndFittedInpaintResult, CompositeContext
+
+TEST_SCENARIOS = [
+    # A standard, centered case with no complex adjustments.
+    pytest.param(
+        dict(
+            test_id="standard_sd15_center",
+            mask_rect=(400, 400, 200, 200),  # y, x, h, w
+            margin="64",
+            resolutions="SD1.5",
+            expected_cropped_shape=(512, 512),
+            expected_context=CompositeContext(x=336, y=336, width=328, height=328)
+        ),
+        id="standard_sd15_center"
+    ),
+    # The user-described wide-margin case.
+    pytest.param(
+        dict(
+            test_id="wide_ideogram_right_edge",
+            mask_rect=(900, 950, 32, 32),
+            margin="64 64 64 400",
+            resolutions="Ideogram",
+            expected_cropped_shape=(512, 1536),  # Should select 1536x512 (AR=3.0)
+            expected_context=CompositeContext(x=544, y=836, width=480, height=160)
+        ),
+        id="wide_ideogram_right_edge"
+    ),
+    # A new test for a tall mask, forcing a ~1:3 aspect ratio.
+    pytest.param(
+        dict(
+            test_id="tall_ideogram_left_edge",
+            mask_rect=(200, 20, 200, 50),
+            margin="100",
+            resolutions="Ideogram",
+            expected_cropped_shape=(1536, 640),  # Should select 640x1536 (AR=0.416)
+            expected_context=CompositeContext(x=0, y=96, width=170, height=408)
+        ),
+        id="tall_ideogram_left_edge"
+    ),
+    # A test where the covering rectangle must be shifted to stay in bounds.
+    pytest.param(
+        dict(
+            test_id="shift_to_fit",
+            mask_rect=(10, 10, 150, 50),
+            margin="40",
+            resolutions="Ideogram",
+            expected_cropped_shape=(1408, 704),  # AR is exactly 0.5
+            expected_context=CompositeContext(x=0, y=0, width=100, height=200)
+        ),
+        id="shift_to_fit"
+    )
+]
 
 
 def create_circle_mask(height, width, center_y, center_x, radius):
     Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
     distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
-    return (distance >= radius).float().unsqueeze(0)
+    return (distance < radius).float().unsqueeze(0)
 
 
 @pytest.fixture
@@ -24,8 +76,8 @@ def image_1024() -> torch.Tensor:
 
 @pytest.fixture
 def rect_mask() -> torch.Tensor:
-    mask = torch.ones(1, 256, 256)
-    mask[:, 100:150, 80:180] = 0.0
+    mask = torch.zeros(1, 256, 256)
+    mask[:, 100:150, 80:180] = 1.0
     return mask
 
 
@@ -56,16 +108,19 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
     crop_node = CropAndFitInpaintToDiffusionSize()
     composite_node = CompositeCroppedAndFittedInpaintResult()
 
-    cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
+    cropped_img, cropped_mask, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
 
     h, w = cropped_img.shape[1:3]
     blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
     inpainted_sim = blue_color.expand(1, h, w, 3)
 
+    # Inpaint the cropped region with the new color
+    inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
+
     final_image, = composite_node.composite_result(
         source_image=sample_image,
         source_mask=mask,
-        inpainted_image=inpainted_sim,
+        inpainted_image=inpainted_cropped,
         composite_context=context
     )
 
@@ -73,8 +128,10 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
 
     bool_mask = mask.squeeze(0).bool()
 
+    # The area inside the mask should be blue
     assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
-    assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask])
+    # The area outside the mask should be unchanged
+    assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
 
 
 def test_wide_ideogram_composite(image_1024):
@@ -88,16 +145,19 @@ def test_wide_ideogram_composite(image_1024):
 
     margin = "64 64 64 400"
 
-    cropped_img, _, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
+    cropped_img, cropped_mask, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
     assert cropped_img.shape[1:3] == (512, 1536)
 
     green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
-    inpainted_sim = green_color.expand(1, 512, 1536, 3)
+    h, w = cropped_img.shape[1:3]
+    inpainted_sim = green_color.expand(1, h, w, 3)
+
+    inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
 
     final_image, = composite_node.composite_result(
         source_image=source_image,
         source_mask=mask,
-        inpainted_image=inpainted_sim,
+        inpainted_image=inpainted_cropped,
         composite_context=context
     )
 
@@ -109,4 +169,55 @@ def test_wide_ideogram_composite(image_1024):
     assert torch.all(final_pixels[:, 1] > final_pixels[:, 0])
     assert torch.all(final_pixels[:, 1] > final_pixels[:, 2])
 
-    assert torch.allclose(final_image[0, 916, 940, :], source_image[0, 916, 940, :])
+    assert torch.allclose(final_image[0][~bool_mask], source_image[0][~bool_mask], atol=1e-2)
+
+
+@pytest.mark.parametrize("scenario", TEST_SCENARIOS)
+def test_end_to_end_scenarios(image_1024, scenario):
+    """
+    A single, comprehensive test to validate the full node pipeline against various scenarios.
+    """
+    source_image = image_1024
+
+    # 1. Setup based on the scenario
+    mask = torch.zeros_like(source_image[..., 0])
+    y, x, h, w = scenario["mask_rect"]
+    mask[:, y:y + h, x:x + w] = 1.0  # Area to inpaint is 1
+
+    crop_node = CropAndFitInpaintToDiffusionSize()
+    composite_node = CompositeCroppedAndFittedInpaintResult()
+
+    # 2. Run the first node
+    cropped_img, cropped_mask, context = crop_node.crop_and_fit(
+        source_image, mask, scenario["resolutions"], scenario["margin"]
+    )
+
+    # 3. Assert the outputs of the first node
+    assert cropped_img.shape[1:3] == scenario["expected_cropped_shape"]
+    assert context == scenario["expected_context"]
+
+    # 4. Simulate inpainting
+    green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
+    sim_h, sim_w = cropped_img.shape[1:3]
+    inpainted_sim = green_color.expand(1, sim_h, sim_w, -1)
+
+    # Inpaint the cropped region with the new color, respecting the mask feathering
+    inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
+
+    # 5. Run the second node
+    final_image, = composite_node.composite_result(
+        source_image=source_image,
+        source_mask=mask,
+        inpainted_image=inpainted_cropped,
+        composite_context=context
+    )
+
+    # 6. Assert the final composited image
+    assert final_image.shape == source_image.shape
+
+    # Check that the area to be inpainted (mask==1) is now the green color.
+    bool_mask_to_inpaint = (mask.squeeze(0) > 0.0)
+    assert torch.allclose(final_image[0][bool_mask_to_inpaint], green_color.squeeze(), atol=1e-2)
+
+    # Check that the area that was not masked is completely unchanged.
+    assert torch.allclose(final_image[0][~bool_mask_to_inpaint], source_image[0][~bool_mask_to_inpaint], atol=1e-2)
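
Reviewer note, not part of the patch: the updated tests pin down a contract
for CompositeCroppedAndFittedInpaintResult.composite_result, namely resize
the diffusion-sized result back to the CompositeContext rectangle and blend
it into the source through the mask. A minimal sketch of that contract,
assuming BHWC float images and a BHW mask with 1 = inpaint; the node's actual
implementation (e.g. any feathering) may differ:

import torch
import torch.nn.functional as F

def composite_sketch(source, mask, inpainted, ctx):
    # ctx carries the CompositeContext fields: x, y, width, height
    resized = F.interpolate(inpainted.permute(0, 3, 1, 2),  # BHWC -> BCHW
                            size=(ctx.height, ctx.width),
                            mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
    out = source.clone()
    ys, xs = slice(ctx.y, ctx.y + ctx.height), slice(ctx.x, ctx.x + ctx.width)
    m = mask[:, ys, xs].unsqueeze(-1)  # per-pixel blend weight
    out[:, ys, xs, :] = resized * m + out[:, ys, xs, :] * (1 - m)
    return out

# e.g. for the wide case:
# out = composite_sketch(src, mask, result,
#                        CompositeContext(x=544, y=836, width=480, height=160))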