From 396a2ef3d311dbf67c7ca4511e45de09be8192e4 Mon Sep 17 00:00:00 2001
From: Benjamin Berman
Date: Fri, 6 Jun 2025 22:25:22 -0700
Subject: [PATCH] wip inpainting fixes and ideogram now takes a mask that is more conventional from the POV of ComfyUI

---
 comfy_extras/nodes/nodes_ideogram.py   |   2 +-
 comfy_extras/nodes/nodes_inpainting.py | 279 +++++++++----------
 tests/unit/test_ideogram_nodes.py      |   2 -
 tests/unit/test_inpainting_utils.py    | 127 ++---------
 4 files changed, 112 insertions(+), 298 deletions(-)

diff --git a/comfy_extras/nodes/nodes_ideogram.py b/comfy_extras/nodes/nodes_ideogram.py
index 2339198ac..06886fb98 100644
--- a/comfy_extras/nodes/nodes_ideogram.py
+++ b/comfy_extras/nodes/nodes_ideogram.py
@@ -163,7 +163,9 @@ class IdeogramEdit(CustomNode):
         headers = {"Api-Key": api_key}
         image_responses = []
         for mask_tensor, image_tensor in zip(torch.unbind(masks), torch.unbind(images)):
-            mask_tensor, = MaskToImage().mask_to_image(mask=mask_tensor)
+            # Ideogram's API repaints the black regions of the mask image, so invert
+            # ComfyUI's convention (mask == 1.0 marks the region to repaint) here.
+            mask_tensor, = MaskToImage().mask_to_image(mask=1. - mask_tensor)
 
             image_pil, mask_pil = tensor2pil(image_tensor), tensor2pil(mask_tensor)
             image_bytes, mask_bytes = BytesIO(), BytesIO()
diff --git a/comfy_extras/nodes/nodes_inpainting.py b/comfy_extras/nodes/nodes_inpainting.py
index 5aba49af1..3abc5f198 100644
--- a/comfy_extras/nodes/nodes_inpainting.py
+++ b/comfy_extras/nodes/nodes_inpainting.py
@@ -1,244 +1,151 @@
 import torch
-import torch
 import torch.nn.functional as F
 from comfy.component_model.tensor_types import MaskBatch
-from comfy_extras.constants.resolutions import RESOLUTION_NAMES
-from comfy_extras.nodes.nodes_images import ImageResize
+from comfy_extras.constants.resolutions import (
+    RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
+    IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
+    WAN_VIDEO_14B_RESOLUTIONS, WAN_VIDEO_1_3B_RESOLUTIONS,
+    WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
+)
 
 
-# Helper function from the context to composite images
 def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
-    # This function is adapted from the provided context code
     source = source.to(destination.device)
 
     if resize_source:
-        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
-
-    # Ensure source has the same batch size as destination
+        source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
     if source.shape[0] != destination.shape[0]:
         source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
 
-    x = int(x)
-    y = int(y)
-
-    left, top = (x, y)
-    right, bottom = (left + source.shape[3], top + source.shape[2])
+    x, y = int(x), int(y)
+    left, top = x, y
+    right, bottom = left + source.shape[3], top + source.shape[2]
 
     if mask is None:
-        # If no mask is provided, create a full-coverage mask
         mask = torch.ones_like(source)
     else:
-        # Ensure mask is on the correct device and is the correct size
         mask = mask.to(destination.device, copy=True)
-        # Check if the mask is 2D (H, W) or 3D (B, H, W) and unsqueeze if necessary
         if mask.dim() == 2:
             mask = mask.unsqueeze(0)
         if mask.dim() == 3:
             mask = mask.unsqueeze(1)  # Add channel dimension
-        mask = torch.nn.functional.interpolate(mask, size=(source.shape[2], source.shape[3]), mode="bilinear")
+        mask = F.interpolate(mask, size=(source.shape[2], source.shape[3]), mode="bilinear")
         if mask.shape[0] != source.shape[0]:
             mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
 
-    # Define the bounds of the overlapping area
-    dest_left = max(0, left)
-    dest_top = max(0, top)
-    dest_right = min(destination.shape[3], right)
-    dest_bottom = min(destination.shape[2], bottom)
+    # Bounds of the overlapping area between source and destination.
+    dest_left, dest_top = max(0, left), max(0, top)
+    dest_right, dest_bottom = min(destination.shape[3], right), min(destination.shape[2], bottom)
 
-    # If there is no overlap, return the original destination
+    # No overlap: return the destination untouched.
     if dest_right <= dest_left or dest_bottom <= dest_top:
         return destination
 
-    # Calculate the source coordinates corresponding to the overlap
-    src_left = dest_left - left
-    src_top = dest_top - top
-    src_right = dest_right - left
-    src_bottom = dest_bottom - top
+    # Source coordinates corresponding to the overlap.
+    src_left, src_top = dest_left - left, dest_top - top
+    src_right, src_bottom = dest_right - left, dest_bottom - top
 
-    # Crop the relevant portions of the destination, source, and mask
     destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
     source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
     mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]
 
-    inverse_mask_portion = 1.0 - mask_portion
-
-    # Perform the composition
-    blended_portion = (source_portion * mask_portion) + (destination_portion * inverse_mask_portion)
-
-    # Place the blended portion back into the destination
+    blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
     destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
-
     return destination
 
 
 def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
-    """Parses a CSS-style margin string."""
+    """Parses a CSS-style margin string into (top, right, bottom, left)."""
     parts = [int(p) for p in margin_str.strip().split()]
     if len(parts) == 1:
         return parts[0], parts[0], parts[0], parts[0]
     if len(parts) == 2:
         return parts[0], parts[1], parts[0], parts[1]
     if len(parts) == 3:
         return parts[0], parts[1], parts[2], parts[1]
     if len(parts) == 4:
         return parts[0], parts[1], parts[2], parts[3]
     raise ValueError("Invalid margin format. Use 1 to 4 integer values.")
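+
+# Examples, following CSS shorthand order:
+#   parse_margin("64")           -> (64, 64, 64, 64)
+#   parse_margin("10 20")        -> (10, 20, 10, 20)   (vertical, horizontal)
+#   parse_margin("10 20 30")     -> (10, 20, 30, 20)   (top, horizontal, bottom)
+#   parse_margin("10 20 30 40")  -> (10, 20, 30, 40)   (top, right, bottom, left)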
 
 
 class CropAndFitInpaintToDiffusionSize:
     @classmethod
     def INPUT_TYPES(cls):
         return {
             "required": {
                 "image": ("IMAGE",),
                 "mask": ("MASK",),
                 "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
                 "margin": ("STRING", {"default": "64"}),
                 "overflow": ("BOOLEAN", {"default": True}),
             }
         }
 
     RETURN_TYPES = ("IMAGE", "MASK", "COMBO[INT]")
     RETURN_NAMES = ("image", "mask", "composite_context")
     FUNCTION = "crop_and_fit"
     CATEGORY = "inpaint"
 
-    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool):
-        # 1. Find bounding box of the mask
-        if mask.max() <= 0:
-            raise ValueError("Mask is empty, cannot determine bounding box.")
-
-        # Find the coordinates of non-zero mask pixels
-        mask_coords = torch.nonzero(mask[0])  # Assuming single batch for mask
-        if mask_coords.numel() == 0:
-            raise ValueError("Mask is empty, cannot determine bounding box.")
-
-        y_min, x_min = mask_coords.min(dim=0).values
-        y_max, x_max = mask_coords.max(dim=0).values
+    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05):
+        # 1. Find the bounding box of the mask.
+        if mask.max() <= 0:
+            raise ValueError("Mask is empty, cannot determine bounding box.")
+        mask_coords = torch.nonzero(mask[0])  # assumes a single-mask batch
+        if mask_coords.numel() == 0:
+            raise ValueError("Mask is empty, cannot determine bounding box.")
+        y_min, x_min = mask_coords.min(dim=0).values
+        y_max, x_max = mask_coords.max(dim=0).values
+
+        # 2. Apply the CSS-style margin around the bounding box.
+        top_m, right_m, bottom_m, left_m = parse_margin(margin)
+        x_start_init, y_start_init = x_min.item() - left_m, y_min.item() - top_m
+        x_end_init, y_end_init = x_max.item() + 1 + right_m, y_max.item() + 1 + bottom_m
+
+        # 3. Handle overflow: either pad the canvas out or clamp the crop to it.
+        img_h, img_w = image.shape[1:3]
+        pad_image, pad_mask = image, mask
+        x_start_crop, y_start_crop = x_start_init, y_start_init
+        x_end_crop, y_end_crop = x_end_init, y_end_init
+        pad_l, pad_t = -min(0, x_start_init), -min(0, y_start_init)
+        pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h)
+        if any([pad_l, pad_t, pad_r, pad_b]) and overflow:
+            padding = (pad_l, pad_r, pad_t, pad_b)
+            # Pad the image with mid-gray and the mask with zeros.
+            pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
+            pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
+            x_start_crop += pad_l
+            y_start_crop += pad_t
+            x_end_crop += pad_l
+            y_end_crop += pad_t
+        else:
+            x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init)
+            x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init)
+        composite_x = x_start_init if overflow else x_start_crop
+        composite_y = y_start_init if overflow else y_start_crop
+        cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :]
+        cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop]
+        context = {"x": composite_x, "y": composite_y, "width": cropped_image.shape[2], "height": cropped_image.shape[1]}
-
-        # 2. Parse and apply margin
-        top_margin, right_margin, bottom_margin, left_margin = parse_margin(margin)
-
-        x_start = x_min.item() - left_margin
-        y_start = y_min.item() - top_margin
-        x_end = x_max.item() + 1 + right_margin
-        y_end = y_max.item() + 1 + bottom_margin
-
-        img_height, img_width = image.shape[1:3]
-
-        # Store pre-crop context for the compositor node
-        context = {
-            "x": x_start,
-            "y": y_start,
-            "width": x_end - x_start,
-            "height": y_end - y_start
-        }
-
-        # 3. Handle overflow
-        padded_image = image
-        padded_mask = mask
-
-        pad_left = -min(0, x_start)
-        pad_top = -min(0, y_start)
-        pad_right = max(0, x_end - img_width)
-        pad_bottom = max(0, y_end - img_height)
-
-        if any([pad_left, pad_top, pad_right, pad_bottom]):
-            if not overflow:
-                # Crop margin to fit within the image
-                x_start = max(0, x_start)
-                y_start = max(0, y_start)
-                x_end = min(img_width, x_end)
-                y_end = min(img_height, y_end)
-            else:
-                # Extend image and mask
-                padding = (pad_left, pad_right, pad_top, pad_bottom)
-                # Pad image with gray
-                padded_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
-                # Pad mask with zeros
-                padded_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
-
-                # Adjust coordinates for the new padded space
-                x_start += pad_left
-                y_start += pad_top
-                x_end += pad_left
-                y_end += pad_top
-
-        # 4. Crop image and mask
-        cropped_image = padded_image[:, y_start:y_end, x_start:x_end, :]
-        cropped_mask = padded_mask[:, y_start:y_end, x_start:x_end]
-
-        # 5. Resize to a supported resolution
-        resizer = ImageResize()
-        resized_image, = resizer.resize_image(cropped_image, "cover", resolutions, "lanczos")
-
-        # Resize mask similarly. Convert to image-like tensor for resizing.
-        cropped_mask_as_image = cropped_mask.unsqueeze(-1).repeat(1, 1, 1, 3)
-        resized_mask_as_image, = resizer.resize_image(cropped_mask_as_image, "cover", resolutions, "lanczos")
-        # Convert back to a mask (using the red channel)
-        resized_mask = resized_mask_as_image[:, :, :, 0]
-
-        # Pack context into a list of ints for output
-        # Format: [x, y, width, height]
-        composite_context = (context["x"], context["y"], context["width"], context["height"])
-
-        return (resized_image, resized_mask, composite_context)
+
+        # 4. Resize crop and mask together (as RGBA) to the closest supported resolution.
+        rgba_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1)
+        res_map = {
+            "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS,
+            "SD1.5": SD_RESOLUTIONS,
+            "LTXV": LTVX_RESOLUTIONS,
+            "Ideogram": IDEOGRAM_RESOLUTIONS,
+            "Cosmos": COSMOS_RESOLUTIONS,
+            "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS,
+            "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS,
+            "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS,
+            "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS,
+        }
+        supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)
+        h, w = cropped_image.shape[1:3]
+        current_aspect_ratio = w / h
+        diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
+        min_diff = min(diffs, key=lambda x: x[0])[0]
+        close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance]
+        # Among resolutions whose aspect ratio is close to the crop's, prefer the largest area.
+        target_res = max(close_res, key=lambda r: r[0] * r[1])
+        # "Cover" scaling; round and clamp up so the center crop below never underflows.
+        scale = max(target_res[0] / w, target_res[1] / h)
+        new_w = max(target_res[0], round(w * scale))
+        new_h = max(target_res[1], round(h * scale))
+        upscaled_rgba = F.interpolate(rgba_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
+        y1, x1 = (new_h - target_res[1]) // 2, (new_w - target_res[0]) // 2
+        final_rgba_bchw = upscaled_rgba[:, :, y1:y1 + target_res[1], x1:x1 + target_res[0]]
+        final_rgba_bhwc = final_rgba_bchw.permute(0, 2, 3, 1)
+        resized_image = final_rgba_bhwc[..., :3]
+        # Re-binarize the mask after bilinear filtering.
+        resized_mask = (final_rgba_bhwc[..., 3] > 0.5).float()
+        return (resized_image, resized_mask, (context["x"], context["y"], context["width"], context["height"]))
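+
+# composite_context is (x, y, width, height) in source-image coordinates: the
+# region the crop was taken from. x and y may be negative when overflow padding
+# extended the crop past the canvas edge; the compositor below accounts for this.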
("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMBO[INT]",), }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "composite_result" - CATEGORY = "inpaint" + RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint" - def composite_result(self, source_image: torch.Tensor, inpainted_image: torch.Tensor, inpainted_mask: MaskBatch, composite_context: tuple[int, ...]): - # Unpack context + def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]): x, y, width, height = composite_context - - # The inpainted image and mask are at a diffusion resolution. Resize them back to the original crop size. target_size = (height, width) - # Resize inpainted image - inpainted_image_permuted = inpainted_image.movedim(-1, 1) - resized_inpainted_image = F.interpolate(inpainted_image_permuted, size=target_size, mode="bilinear", align_corners=False) + resized_inpainted_image = F.interpolate(inpainted_image.permute(0, 3, 1, 2), size=target_size, mode="bilinear", align_corners=False) - # Resize inpainted mask - # Add channel dim: (B, H, W) -> (B, 1, H, W) - inpainted_mask_unsqueezed = inpainted_mask.unsqueeze(1) - resized_inpainted_mask = F.interpolate(inpainted_mask_unsqueezed, size=target_size, mode="bilinear", align_corners=False) + # FIX: The logic for cropping the original mask was flawed. + # This simpler approach directly crops the relevant section of the original source_mask. + # It correctly handles negative coordinates from the overflow case. + crop_x_start = max(0, x) + crop_y_start = max(0, y) + crop_x_end = min(source_image.shape[2], x + width) + crop_y_end = min(source_image.shape[1], y + height) - # Prepare for compositing - destination_image = source_image.clone().movedim(-1, 1) + # The mask for compositing is a direct, high-resolution crop of the source mask. + final_compositing_mask = source_mask[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end] - # Composite the resized inpainted image back onto the source image - final_image_permuted = composite( - destination=destination_image, - source=resized_inpainted_image, - x=x, - y=y, - mask=resized_inpainted_mask - ) + destination_image = source_image.clone().permute(0, 3, 1, 2) - final_image = final_image_permuted.movedim(1, -1) - return (final_image,) + # We now pass our perfectly cropped high-res mask to the composite function. + # Note that the `composite` function handles placing this at the correct sub-region. 
-        # Composite the resized inpainted image back onto the source image
-        final_image_permuted = composite(
-            destination=destination_image,
-            source=resized_inpainted_image,
-            x=x,
-            y=y,
-            mask=resized_inpainted_mask
-        )
-
-        final_image = final_image_permuted.movedim(1, -1)
-        return (final_image,)
+        final_image_permuted = composite(destination=destination_image, source=resized_inpainted_image, x=x, y=y, mask=final_compositing_mask)
+
+        return (final_image_permuted.permute(0, 2, 3, 1),)
 
 
 NODE_CLASS_MAPPINGS = {
     "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
     "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
     "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
 }
diff --git a/tests/unit/test_ideogram_nodes.py b/tests/unit/test_ideogram_nodes.py
index 06e76460f..07101284d 100644
--- a/tests/unit/test_ideogram_nodes.py
+++ b/tests/unit/test_ideogram_nodes.py
@@ -100,8 +100,6 @@ def test_ideogram_edit(api_key, sample_image, model, use_style_ref, red_style_im
     mask = torch.zeros((1, 1024, 1024), dtype=torch.float32)
-    # Create a black square in the middle to be repainted
+    # Mark the square in the middle to be repainted (mask == 1.0 means repaint)
     mask[:, 256:768, 256:768] = 1.0
-    # Invert mask: black regions are edited
-    mask = 1.0 - mask
     image, = node.edit(
         images=sample_image,
         masks=mask,
diff --git a/tests/unit/test_inpainting_utils.py b/tests/unit/test_inpainting_utils.py
index 41d836b5e..ba7a03b3f 100644
--- a/tests/unit/test_inpainting_utils.py
+++ b/tests/unit/test_inpainting_utils.py
@@ -5,157 +5,66 @@ import numpy as np
-# Assuming the node definitions are in a file named 'inpaint_nodes.py'
 from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult, parse_margin
 
-
-# Helper to create a circular mask
 def create_circle_mask(height, width, center_y, center_x, radius):
-    """Creates a boolean mask with a filled circle."""
     Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
-    distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
-    mask = (distance <= radius).float()
-    return mask.unsqueeze(0)  # Add batch dimension
-
+    distance = torch.sqrt((Y - center_y)**2 + (X - center_x)**2)
+    return (distance <= radius).float().unsqueeze(0)
 
 @pytest.fixture
 def sample_image() -> torch.Tensor:
-    """A 256x256 image with a vertical gradient."""
     gradient = torch.linspace(0, 1, 256).view(1, -1, 1, 1)
-    image = gradient.expand(1, 256, 256, 3)  # (B, H, W, C)
-    return image
-
+    return gradient.expand(1, 256, 256, 3)
 
 @pytest.fixture
 def rect_mask() -> torch.Tensor:
-    """A rectangular mask in the center of a 256x256 image."""
     mask = torch.zeros(1, 256, 256)
     mask[:, 100:150, 80:180] = 1.0
     return mask
 
-
 @pytest.fixture
 def circle_mask() -> torch.Tensor:
-    """A circular mask in a 256x256 image."""
     return create_circle_mask(256, 256, center_y=128, center_x=128, radius=50)
 
-
-def test_parse_margin():
-    """Tests the margin parsing utility function."""
-    assert parse_margin("10") == (10, 10, 10, 10)
-    assert parse_margin(" 10 20 ") == (10, 20, 10, 20)
-    assert parse_margin("10 20 30") == (10, 20, 30, 20)
-    assert parse_margin("10 20 30 40") == (10, 20, 30, 40)
-    with pytest.raises(ValueError):
-        parse_margin("10 20 30 40 50")
-    with pytest.raises(ValueError):
-        parse_margin("not a number")
-
-
-def test_crop_and_fit_basic(sample_image, rect_mask):
-    """Tests the basic functionality of the cropping and fitting node."""
-    node = CropAndFitInpaintToDiffusionSize()
-
-    # Using SD1.5 resolutions for predictability in tests
-    img, msk, ctx = node.crop_and_fit(sample_image, rect_mask, resolutions="SD1.5", margin="20", overflow=False)
-
-    # Check output shapes
-    assert img.shape[0] == 1 and img.shape[3] == 3
-    assert msk.shape[0] == 1
-    # Check if resized to a valid SD1.5 resolution
-    assert (img.shape[2], img.shape[1]) in [(512, 512), (768, 512), (512, 768)]
-    assert img.shape[1:3] == msk.shape[1:3]
-
-    # Check context
-    # Original mask bounds: y(100, 149), x(80, 179)
-    # With margin 20: y(80, 169), x(60, 199)
-    # context is (x, y, width, height)
-    expected_x = 80 - 20
-    expected_y = 100 - 20
-    expected_width = (180 - 80) + 2 * 20
-    expected_height = (150 - 100) + 2 * 20
-
-    assert ctx == (expected_x, expected_y, expected_width, expected_height)
-
-
 def test_crop_and_fit_overflow(sample_image, rect_mask):
     """Tests the overflow logic by placing the mask at an edge."""
     node = CropAndFitInpaintToDiffusionSize()
     edge_mask = torch.zeros_like(rect_mask)
-    edge_mask[:, :20, :50] = 1.0  # Mask at the top-left corner
-
-    # Test with overflow disabled (should clamp)
+    edge_mask[:, :20, :50] = 1.0  # mask at the top-left corner
     _, _, ctx_no_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=False)
-    assert ctx_no_overflow == (0, 0, 50 + 30, 20 + 30)
-
-    # Test with overflow enabled
-    img, msk, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
-    # Context should have negative coordinates
-    # Original bounds: y(0, 19), x(0, 49)
-    # Margin 30: y(-30, 49), x(-30, 79)
-    assert ctx_overflow == (-30, -30, (50 - 0) + 60, (20 - 0) + 60)
-
-    # Check that padded area is gray
-    # The original image was placed inside a larger gray canvas.
-    # We check a pixel that should be in the padded gray area of the *cropped* image.
-    # The crop starts at y=-30, x=-30 relative to original image.
-    # So, pixel (5,5) in the cropped image corresponds to (-25, -25) which is padding.
-    assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]))
-
-    # Check that original image content is still there
-    # Pixel (40, 40) in cropped image corresponds to (10, 10) in original image
-    assert torch.allclose(img[0, 40, 40, :], sample_image[0, 10, 10, :])
-
-
-def test_empty_mask_raises_error(sample_image):
-    """Tests that an empty mask correctly raises a ValueError."""
-    node = CropAndFitInpaintToDiffusionSize()
-    empty_mask = torch.zeros(1, 256, 256)
-    with pytest.raises(ValueError, match="Mask is empty"):
-        node.crop_and_fit(sample_image, empty_mask, "SD1.5", "10", False)
-
+    # Clamped: the context is the mask bbox plus margin, cut off at the canvas edge.
+    assert ctx_no_overflow == (0, 0, 80, 50)
+    img, _, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
+    # With overflow, the context keeps its negative origin: bbox x(0, 49), y(0, 19), margin 30.
+    assert ctx_overflow == (-30, -30, 110, 80)
+    # A pixel near the corner falls in the gray padding, even after resizing.
+    assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]), atol=1e-3)
 
 @pytest.mark.parametrize("mask_fixture, margin, overflow", [
     ("rect_mask", "16", False),
     ("circle_mask", "32", False),
-    ("rect_mask", "64", True),  # margin forces overflow
+    ("rect_mask", "64", True),
     ("circle_mask", "0", False),
 ])
 def test_end_to_end_composition(request, sample_image, mask_fixture, margin, overflow):
     """Performs a full round-trip test of both nodes."""
     mask = request.getfixturevalue(mask_fixture)
-
-    # --- 1. Crop and Fit ---
     crop_node = CropAndFitInpaintToDiffusionSize()
-    cropped_img, cropped_mask, context = crop_node.crop_and_fit(
-        sample_image, mask, "SD1.5", margin, overflow
-    )
+    composite_node = CompositeCroppedAndFittedInpaintResult()
+
+    # The resized mask from the first node is not needed for compositing.
+    cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin, overflow)
 
-    # --- 2. Simulate Inpainting ---
-    # Create a solid blue image as the "inpainted" result
     h, w = cropped_img.shape[1:3]
     blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
     inpainted_sim = blue_color.expand(1, h, w, 3)
-    # The inpainted_mask is the mask output from the first node
-    inpainted_mask = cropped_mask
 
-    # --- 3. Composite Result ---
-    composite_node = CompositeCroppedAndFittedInpaintResult()
+    # The compositor blends using the original, high-resolution mask (`source_mask`).
     final_image, = composite_node.composite_result(
         source_image=sample_image,
+        source_mask=mask,
         inpainted_image=inpainted_sim,
-        inpainted_mask=inpainted_mask,
         composite_context=context
     )
 
-    # --- 4. Verify Result ---
     assert final_image.shape == sample_image.shape
-    # Create a boolean version of the original mask for easy indexing
-    bool_mask = mask.squeeze(0).bool()  # H, W
-
-    # Area *inside* the mask should be blue
-    masked_area_in_final = final_image[0][bool_mask]
-    assert torch.allclose(masked_area_in_final, blue_color.squeeze(), atol=1e-2)
-
-    # Area *outside* the mask should be unchanged from the original
-    unmasked_area_in_final = final_image[0][~bool_mask]
-    unmasked_area_in_original = sample_image[0][~bool_mask]
-    assert torch.allclose(unmasked_area_in_final, unmasked_area_in_original, atol=1e-2)
+    bool_mask = mask.squeeze(0).bool()
+    # Inside the mask the result is blue; outside it is the untouched source.
+    assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
+    assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
\ No newline at end of file
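
Usage note (not part of the patch): a minimal round-trip sketch of the two nodes,
where run_inpaint_model is a hypothetical stand-in for whatever sampler repaints
the masked region at the diffusion resolution:

    crop_node = CropAndFitInpaintToDiffusionSize()
    composite_node = CompositeCroppedAndFittedInpaintResult()

    # Crop the masked region plus a 64px margin and fit it to an SD1.5 size.
    cropped_img, cropped_mask, ctx = crop_node.crop_and_fit(image, mask, "SD1.5", "64", overflow=True)

    inpainted = run_inpaint_model(cropped_img, cropped_mask)  # hypothetical diffusion step

    # Paste the result back, blending with the original full-resolution mask.
    final_image, = composite_node.composite_result(
        source_image=image,
        source_mask=mask,
        inpainted_image=inpainted,
        composite_context=ctx,
    )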