fix inpainting alignment issue

This commit is contained in:
doctorpangloss 2025-06-24 12:27:06 -07:00
parent 8041b1b54d
commit 41dc6ec7fa
2 changed files with 154 additions and 25 deletions

View File

@ -1,7 +1,6 @@
from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from typing import NamedTuple, Optional
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
@ -84,11 +83,12 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
CATEGORY = "inpaint"
def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
if mask.max() <= 0:
raise ValueError("Mask is empty.")
if mask.max() == 0.0:
raise ValueError("Mask is empty (all black).")
mask_coords = torch.nonzero(mask)
if mask_coords.numel() == 0:
raise ValueError("Mask is empty.")
raise ValueError("Mask is empty (all black).")
y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
y_min, x_min = y_coords.min().item(), x_coords.min().item()
@ -99,8 +99,11 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m
img_h, img_w = image.shape[1:3]
clamped_x_start, clamped_y_start = max(0, x_start_expanded), max(0, y_start_expanded)
clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)
clamped_x_start = max(0, x_start_expanded)
clamped_y_start = max(0, y_start_expanded)
clamped_x_end = min(img_w, x_end_expanded)
clamped_y_end = min(img_h, y_end_expanded)
initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
if initial_w <= 0 or initial_h <= 0:
@ -112,15 +115,30 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
target_ar = target_res[0] / target_res[1]
current_ar = initial_w / initial_h
final_x, final_y = float(clamped_x_start), float(clamped_y_start)
final_w, final_h = float(initial_w), float(initial_h)
if current_ar > target_ar:
final_w = initial_h * target_ar
final_x += (initial_w - final_w) / 2
cover_w, cover_h = float(initial_w), float(initial_w) / target_ar
else:
final_h = initial_w / target_ar
final_y += (initial_h - final_h) / 2
cover_h, cover_w = float(initial_h), float(initial_h) * target_ar
if cover_w > img_w or cover_h > img_h:
final_x, final_y, final_w, final_h = 0, 0, img_w, img_h
full_img_ar = img_w / img_h
diffs_full = [(abs(res[0] / res[1] - full_img_ar), res) for res in supported_resolutions]
target_res = min(diffs_full, key=lambda x: x[0])[1]
else:
center_x = clamped_x_start + initial_w / 2
center_y = clamped_y_start + initial_h / 2
final_x, final_y = center_x - cover_w / 2, center_y - cover_h / 2
final_w, final_h = cover_w, cover_h
if final_x < 0:
final_x = 0
if final_y < 0:
final_y = 0
if final_x + final_w > img_w:
final_x = img_w - final_w
if final_y + final_h > img_h:
final_y = img_h - final_h
final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)

View File

@ -1,13 +1,65 @@
import pytest
import torch
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, \
CompositeCroppedAndFittedInpaintResult, CompositeContext
TEST_SCENARIOS = [
# A standard, centered case with no complex adjustments.
pytest.param(
dict(
test_id="standard_sd15_center",
mask_rect=(400, 400, 200, 200), # y, x, h, w
margin="64",
resolutions="SD1.5",
expected_cropped_shape=(512, 512),
expected_context=CompositeContext(x=336, y=336, width=328, height=328)
),
id="standard_sd15_center"
),
# The user-described wide-margin case.
pytest.param(
dict(
test_id="wide_ideogram_right_edge",
mask_rect=(900, 950, 32, 32),
margin="64 64 64 400",
resolutions="Ideogram",
expected_cropped_shape=(512, 1536), # Should select 1536x512 (AR=3.0)
expected_context=CompositeContext(x=544, y=836, width=480, height=160)
),
id="wide_ideogram_right_edge"
),
# A new test for a tall mask, forcing a ~1:3 aspect ratio.
pytest.param(
dict(
test_id="tall_ideogram_left_edge",
mask_rect=(200, 20, 200, 50),
margin="100",
resolutions="Ideogram",
expected_cropped_shape=(1536, 640), # Should select 640x1536 (AR=0.416)
expected_context=CompositeContext(x=0, y=96, width=170, height=408)
),
id="tall_ideogram_left_edge"
),
# A test where the covering rectangle must be shifted to stay in bounds.
pytest.param(
dict(
test_id="shift_to_fit",
mask_rect=(10, 10, 150, 50),
margin="40",
resolutions="Ideogram",
expected_cropped_shape=(1408, 704), # AR is exactly 0.5
expected_context=CompositeContext(x=0, y=0, width=100, height=200)
),
id="shift_to_fit"
)
]
def create_circle_mask(height, width, center_y, center_x, radius):
Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
return (distance >= radius).float().unsqueeze(0)
return (distance < radius).float().unsqueeze(0)
@pytest.fixture
@ -24,8 +76,8 @@ def image_1024() -> torch.Tensor:
@pytest.fixture
def rect_mask() -> torch.Tensor:
mask = torch.ones(1, 256, 256)
mask[:, 100:150, 80:180] = 0.0
mask = torch.zeros(1, 256, 256)
mask[:, 100:150, 80:180] = 1.0
return mask
@ -56,16 +108,19 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
crop_node = CropAndFitInpaintToDiffusionSize()
composite_node = CompositeCroppedAndFittedInpaintResult()
cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
cropped_img, cropped_mask, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
h, w = cropped_img.shape[1:3]
blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
inpainted_sim = blue_color.expand(1, h, w, 3)
# Inpaint the cropped region with the new color
inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
final_image, = composite_node.composite_result(
source_image=sample_image,
source_mask=mask,
inpainted_image=inpainted_sim,
inpainted_image=inpainted_cropped,
composite_context=context
)
@ -73,8 +128,10 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
bool_mask = mask.squeeze(0).bool()
# The area inside the mask should be blue
assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask])
# The area outside the mask should be unchanged
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
def test_wide_ideogram_composite(image_1024):
@ -88,16 +145,19 @@ def test_wide_ideogram_composite(image_1024):
margin = "64 64 64 400"
cropped_img, _, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
cropped_img, cropped_mask, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
assert cropped_img.shape[1:3] == (512, 1536)
green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
inpainted_sim = green_color.expand(1, 512, 1536, 3)
h, w = cropped_img.shape[1:3]
inpainted_sim = green_color.expand(1, h, w, 3)
inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
final_image, = composite_node.composite_result(
source_image=source_image,
source_mask=mask,
inpainted_image=inpainted_sim,
inpainted_image=inpainted_cropped,
composite_context=context
)
@ -109,4 +169,55 @@ def test_wide_ideogram_composite(image_1024):
assert torch.all(final_pixels[:, 1] > final_pixels[:, 0])
assert torch.all(final_pixels[:, 1] > final_pixels[:, 2])
assert torch.allclose(final_image[0, 916, 940, :], source_image[0, 916, 940, :])
assert torch.allclose(final_image[0][~bool_mask], source_image[0][~bool_mask], atol=1e-2)
@pytest.mark.parametrize("scenario", TEST_SCENARIOS)
def test_end_to_end_scenarios(image_1024, scenario):
"""
A single, comprehensive test to validate the full node pipeline against various scenarios.
"""
source_image = image_1024
# 1. Setup based on the scenario
mask = torch.zeros_like(source_image[..., 0])
y, x, h, w = scenario["mask_rect"]
mask[:, y:y + h, x:x + w] = 1.0 # Area to inpaint is 1
crop_node = CropAndFitInpaintToDiffusionSize()
composite_node = CompositeCroppedAndFittedInpaintResult()
# 2. Run the first node
cropped_img, cropped_mask, context = crop_node.crop_and_fit(
source_image, mask, scenario["resolutions"], scenario["margin"]
)
# 3. Assert the outputs of the first node
assert cropped_img.shape[1:3] == scenario["expected_cropped_shape"]
assert context == scenario["expected_context"]
# 4. Simulate inpainting
green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
sim_h, sim_w = cropped_img.shape[1:3]
inpainted_sim = green_color.expand(1, sim_h, sim_w, -1)
# Inpaint the cropped region with the new color, respecting the mask feathering
inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
# 5. Run the second node
final_image, = composite_node.composite_result(
source_image=source_image,
source_mask=mask,
inpainted_image=inpainted_cropped,
composite_context=context
)
# 6. Assert the final composited image
assert final_image.shape == source_image.shape
# Check that the area to be inpainted (mask==1) is now the green color.
bool_mask_to_inpaint = (mask.squeeze(0) > 0.0)
assert torch.allclose(final_image[0][bool_mask_to_inpaint], green_color.squeeze(), atol=1e-2)
# Check that the area that was not masked is completely unchanged.
assert torch.allclose(final_image[0][~bool_mask_to_inpaint], source_image[0][~bool_mask_to_inpaint], atol=1e-2)