mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
fix inpainting alignment issue
This commit is contained in:
parent
8041b1b54d
commit
41dc6ec7fa
@ -1,7 +1,6 @@
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
@ -84,11 +83,12 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
|
||||
CATEGORY = "inpaint"
|
||||
|
||||
def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
|
||||
if mask.max() <= 0:
|
||||
raise ValueError("Mask is empty.")
|
||||
if mask.max() == 0.0:
|
||||
raise ValueError("Mask is empty (all black).")
|
||||
|
||||
mask_coords = torch.nonzero(mask)
|
||||
if mask_coords.numel() == 0:
|
||||
raise ValueError("Mask is empty.")
|
||||
raise ValueError("Mask is empty (all black).")
|
||||
|
||||
y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
|
||||
y_min, x_min = y_coords.min().item(), x_coords.min().item()
|
||||
@ -99,8 +99,11 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
|
||||
x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m
|
||||
|
||||
img_h, img_w = image.shape[1:3]
|
||||
clamped_x_start, clamped_y_start = max(0, x_start_expanded), max(0, y_start_expanded)
|
||||
clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)
|
||||
|
||||
clamped_x_start = max(0, x_start_expanded)
|
||||
clamped_y_start = max(0, y_start_expanded)
|
||||
clamped_x_end = min(img_w, x_end_expanded)
|
||||
clamped_y_end = min(img_h, y_end_expanded)
|
||||
|
||||
initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
|
||||
if initial_w <= 0 or initial_h <= 0:
|
||||
@ -112,15 +115,30 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
|
||||
target_ar = target_res[0] / target_res[1]
|
||||
|
||||
current_ar = initial_w / initial_h
|
||||
final_x, final_y = float(clamped_x_start), float(clamped_y_start)
|
||||
final_w, final_h = float(initial_w), float(initial_h)
|
||||
|
||||
if current_ar > target_ar:
|
||||
final_w = initial_h * target_ar
|
||||
final_x += (initial_w - final_w) / 2
|
||||
cover_w, cover_h = float(initial_w), float(initial_w) / target_ar
|
||||
else:
|
||||
final_h = initial_w / target_ar
|
||||
final_y += (initial_h - final_h) / 2
|
||||
cover_h, cover_w = float(initial_h), float(initial_h) * target_ar
|
||||
|
||||
if cover_w > img_w or cover_h > img_h:
|
||||
final_x, final_y, final_w, final_h = 0, 0, img_w, img_h
|
||||
full_img_ar = img_w / img_h
|
||||
diffs_full = [(abs(res[0] / res[1] - full_img_ar), res) for res in supported_resolutions]
|
||||
target_res = min(diffs_full, key=lambda x: x[0])[1]
|
||||
else:
|
||||
center_x = clamped_x_start + initial_w / 2
|
||||
center_y = clamped_y_start + initial_h / 2
|
||||
final_x, final_y = center_x - cover_w / 2, center_y - cover_h / 2
|
||||
final_w, final_h = cover_w, cover_h
|
||||
|
||||
if final_x < 0:
|
||||
final_x = 0
|
||||
if final_y < 0:
|
||||
final_y = 0
|
||||
if final_x + final_w > img_w:
|
||||
final_x = img_w - final_w
|
||||
if final_y + final_h > img_h:
|
||||
final_y = img_h - final_h
|
||||
|
||||
final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)
|
||||
|
||||
|
||||
@ -1,13 +1,65 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult
|
||||
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, \
|
||||
CompositeCroppedAndFittedInpaintResult, CompositeContext
|
||||
|
||||
TEST_SCENARIOS = [
|
||||
# A standard, centered case with no complex adjustments.
|
||||
pytest.param(
|
||||
dict(
|
||||
test_id="standard_sd15_center",
|
||||
mask_rect=(400, 400, 200, 200), # y, x, h, w
|
||||
margin="64",
|
||||
resolutions="SD1.5",
|
||||
expected_cropped_shape=(512, 512),
|
||||
expected_context=CompositeContext(x=336, y=336, width=328, height=328)
|
||||
),
|
||||
id="standard_sd15_center"
|
||||
),
|
||||
# The user-described wide-margin case.
|
||||
pytest.param(
|
||||
dict(
|
||||
test_id="wide_ideogram_right_edge",
|
||||
mask_rect=(900, 950, 32, 32),
|
||||
margin="64 64 64 400",
|
||||
resolutions="Ideogram",
|
||||
expected_cropped_shape=(512, 1536), # Should select 1536x512 (AR=3.0)
|
||||
expected_context=CompositeContext(x=544, y=836, width=480, height=160)
|
||||
),
|
||||
id="wide_ideogram_right_edge"
|
||||
),
|
||||
# A new test for a tall mask, forcing a ~1:3 aspect ratio.
|
||||
pytest.param(
|
||||
dict(
|
||||
test_id="tall_ideogram_left_edge",
|
||||
mask_rect=(200, 20, 200, 50),
|
||||
margin="100",
|
||||
resolutions="Ideogram",
|
||||
expected_cropped_shape=(1536, 640), # Should select 640x1536 (AR=0.416)
|
||||
expected_context=CompositeContext(x=0, y=96, width=170, height=408)
|
||||
),
|
||||
id="tall_ideogram_left_edge"
|
||||
),
|
||||
# A test where the covering rectangle must be shifted to stay in bounds.
|
||||
pytest.param(
|
||||
dict(
|
||||
test_id="shift_to_fit",
|
||||
mask_rect=(10, 10, 150, 50),
|
||||
margin="40",
|
||||
resolutions="Ideogram",
|
||||
expected_cropped_shape=(1408, 704), # AR is exactly 0.5
|
||||
expected_context=CompositeContext(x=0, y=0, width=100, height=200)
|
||||
),
|
||||
id="shift_to_fit"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def create_circle_mask(height, width, center_y, center_x, radius):
|
||||
Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
|
||||
distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
|
||||
return (distance >= radius).float().unsqueeze(0)
|
||||
return (distance < radius).float().unsqueeze(0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -24,8 +76,8 @@ def image_1024() -> torch.Tensor:
|
||||
|
||||
@pytest.fixture
|
||||
def rect_mask() -> torch.Tensor:
|
||||
mask = torch.ones(1, 256, 256)
|
||||
mask[:, 100:150, 80:180] = 0.0
|
||||
mask = torch.zeros(1, 256, 256)
|
||||
mask[:, 100:150, 80:180] = 1.0
|
||||
return mask
|
||||
|
||||
|
||||
@ -56,16 +108,19 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
|
||||
crop_node = CropAndFitInpaintToDiffusionSize()
|
||||
composite_node = CompositeCroppedAndFittedInpaintResult()
|
||||
|
||||
cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
|
||||
cropped_img, cropped_mask, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
|
||||
|
||||
h, w = cropped_img.shape[1:3]
|
||||
blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
|
||||
inpainted_sim = blue_color.expand(1, h, w, 3)
|
||||
|
||||
# Inpaint the cropped region with the new color
|
||||
inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
|
||||
|
||||
final_image, = composite_node.composite_result(
|
||||
source_image=sample_image,
|
||||
source_mask=mask,
|
||||
inpainted_image=inpainted_sim,
|
||||
inpainted_image=inpainted_cropped,
|
||||
composite_context=context
|
||||
)
|
||||
|
||||
@ -73,8 +128,10 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
|
||||
|
||||
bool_mask = mask.squeeze(0).bool()
|
||||
|
||||
# The area inside the mask should be blue
|
||||
assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
|
||||
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask])
|
||||
# The area outside the mask should be unchanged
|
||||
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
|
||||
|
||||
|
||||
def test_wide_ideogram_composite(image_1024):
|
||||
@ -88,16 +145,19 @@ def test_wide_ideogram_composite(image_1024):
|
||||
|
||||
margin = "64 64 64 400"
|
||||
|
||||
cropped_img, _, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
|
||||
cropped_img, cropped_mask, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
|
||||
assert cropped_img.shape[1:3] == (512, 1536)
|
||||
|
||||
green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
|
||||
inpainted_sim = green_color.expand(1, 512, 1536, 3)
|
||||
h, w = cropped_img.shape[1:3]
|
||||
inpainted_sim = green_color.expand(1, h, w, 3)
|
||||
|
||||
inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
|
||||
|
||||
final_image, = composite_node.composite_result(
|
||||
source_image=source_image,
|
||||
source_mask=mask,
|
||||
inpainted_image=inpainted_sim,
|
||||
inpainted_image=inpainted_cropped,
|
||||
composite_context=context
|
||||
)
|
||||
|
||||
@ -109,4 +169,55 @@ def test_wide_ideogram_composite(image_1024):
|
||||
assert torch.all(final_pixels[:, 1] > final_pixels[:, 0])
|
||||
assert torch.all(final_pixels[:, 1] > final_pixels[:, 2])
|
||||
|
||||
assert torch.allclose(final_image[0, 916, 940, :], source_image[0, 916, 940, :])
|
||||
assert torch.allclose(final_image[0][~bool_mask], source_image[0][~bool_mask], atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scenario", TEST_SCENARIOS)
|
||||
def test_end_to_end_scenarios(image_1024, scenario):
|
||||
"""
|
||||
A single, comprehensive test to validate the full node pipeline against various scenarios.
|
||||
"""
|
||||
source_image = image_1024
|
||||
|
||||
# 1. Setup based on the scenario
|
||||
mask = torch.zeros_like(source_image[..., 0])
|
||||
y, x, h, w = scenario["mask_rect"]
|
||||
mask[:, y:y + h, x:x + w] = 1.0 # Area to inpaint is 1
|
||||
|
||||
crop_node = CropAndFitInpaintToDiffusionSize()
|
||||
composite_node = CompositeCroppedAndFittedInpaintResult()
|
||||
|
||||
# 2. Run the first node
|
||||
cropped_img, cropped_mask, context = crop_node.crop_and_fit(
|
||||
source_image, mask, scenario["resolutions"], scenario["margin"]
|
||||
)
|
||||
|
||||
# 3. Assert the outputs of the first node
|
||||
assert cropped_img.shape[1:3] == scenario["expected_cropped_shape"]
|
||||
assert context == scenario["expected_context"]
|
||||
|
||||
# 4. Simulate inpainting
|
||||
green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
|
||||
sim_h, sim_w = cropped_img.shape[1:3]
|
||||
inpainted_sim = green_color.expand(1, sim_h, sim_w, -1)
|
||||
|
||||
# Inpaint the cropped region with the new color, respecting the mask feathering
|
||||
inpainted_cropped = cropped_img * (1 - cropped_mask.unsqueeze(-1)) + inpainted_sim * cropped_mask.unsqueeze(-1)
|
||||
|
||||
# 5. Run the second node
|
||||
final_image, = composite_node.composite_result(
|
||||
source_image=source_image,
|
||||
source_mask=mask,
|
||||
inpainted_image=inpainted_cropped,
|
||||
composite_context=context
|
||||
)
|
||||
|
||||
# 6. Assert the final composited image
|
||||
assert final_image.shape == source_image.shape
|
||||
|
||||
# Check that the area to be inpainted (mask==1) is now the green color.
|
||||
bool_mask_to_inpaint = (mask.squeeze(0) > 0.0)
|
||||
assert torch.allclose(final_image[0][bool_mask_to_inpaint], green_color.squeeze(), atol=1e-2)
|
||||
|
||||
# Check that the area that was not masked is completely unchanged.
|
||||
assert torch.allclose(final_image[0][~bool_mask_to_inpaint], source_image[0][~bool_mask_to_inpaint], atol=1e-2)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user