mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
wip inpainting fixes and ideogram now takes a mask that is more convention from the POV of comfyui
This commit is contained in:
parent
285b9485f4
commit
396a2ef3d3
@ -163,7 +163,7 @@ class IdeogramEdit(CustomNode):
|
||||
headers = {"Api-Key": api_key}
|
||||
image_responses = []
|
||||
for mask_tensor, image_tensor in zip(torch.unbind(masks), torch.unbind(images)):
|
||||
mask_tensor, = MaskToImage().mask_to_image(mask=mask_tensor)
|
||||
mask_tensor, = MaskToImage().mask_to_image(mask=1. - mask_tensor)
|
||||
|
||||
image_pil, mask_pil = tensor2pil(image_tensor), tensor2pil(mask_tensor)
|
||||
image_bytes, mask_bytes = BytesIO(), BytesIO()
|
||||
|
||||
@ -1,244 +1,151 @@
|
||||
import torch
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.component_model.tensor_types import MaskBatch
|
||||
from comfy_extras.constants.resolutions import RESOLUTION_NAMES
|
||||
from comfy_extras.nodes.nodes_images import ImageResize
|
||||
from comfy_extras.constants.resolutions import (
|
||||
RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
|
||||
IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
|
||||
WAN_VIDEO_14B_RESOLUTIONS, WAN_VIDEO_1_3B_RESOLUTIONS,
|
||||
WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
|
||||
)
|
||||
|
||||
|
||||
# Helper function from the context to composite images
|
||||
def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
|
||||
# This function is adapted from the provided context code
|
||||
source = source.to(destination.device)
|
||||
if resize_source:
|
||||
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
||||
|
||||
# Ensure source has the same batch size as destination
|
||||
source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
||||
if source.shape[0] != destination.shape[0]:
|
||||
source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
|
||||
|
||||
x = int(x)
|
||||
y = int(y)
|
||||
|
||||
left, top = (x, y)
|
||||
right, bottom = (left + source.shape[3], top + source.shape[2])
|
||||
x, y = int(x), int(y)
|
||||
left, top = x, y
|
||||
right, bottom = left + source.shape[3], top + source.shape[2]
|
||||
|
||||
if mask is None:
|
||||
# If no mask is provided, create a full-coverage mask
|
||||
mask = torch.ones_like(source)
|
||||
else:
|
||||
# Ensure mask is on the correct device and is the correct size
|
||||
mask = mask.to(destination.device, copy=True)
|
||||
# Check if the mask is 2D (H, W) or 3D (B, H, W) and unsqueeze if necessary
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.dim() == 3:
|
||||
mask = mask.unsqueeze(1) # Add channel dimension
|
||||
mask = torch.nn.functional.interpolate(mask, size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||
if mask.dim() == 2: mask = mask.unsqueeze(0)
|
||||
if mask.dim() == 3: mask = mask.unsqueeze(1)
|
||||
if mask.shape[0] != source.shape[0]:
|
||||
mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
|
||||
|
||||
# Define the bounds of the overlapping area
|
||||
dest_left = max(0, left)
|
||||
dest_top = max(0, top)
|
||||
dest_right = min(destination.shape[3], right)
|
||||
dest_bottom = min(destination.shape[2], bottom)
|
||||
dest_left, dest_top = max(0, left), max(0, top)
|
||||
dest_right, dest_bottom = min(destination.shape[3], right), min(destination.shape[2], bottom)
|
||||
|
||||
# If there is no overlap, return the original destination
|
||||
if dest_right <= dest_left or dest_bottom <= dest_top:
|
||||
return destination
|
||||
if dest_right <= dest_left or dest_bottom <= dest_top: return destination
|
||||
|
||||
# Calculate the source coordinates corresponding to the overlap
|
||||
src_left = dest_left - left
|
||||
src_top = dest_top - top
|
||||
src_right = dest_right - left
|
||||
src_bottom = dest_bottom - top
|
||||
src_left, src_top = dest_left - left, dest_top - top
|
||||
src_right, src_bottom = dest_right - left, dest_bottom
|
||||
|
||||
# Crop the relevant portions of the destination, source, and mask
|
||||
destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
|
||||
source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
|
||||
mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]
|
||||
|
||||
inverse_mask_portion = 1.0 - mask_portion
|
||||
|
||||
# Perform the composition
|
||||
blended_portion = (source_portion * mask_portion) + (destination_portion * inverse_mask_portion)
|
||||
|
||||
# Place the blended portion back into the destination
|
||||
blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
|
||||
destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
|
||||
|
||||
return destination
|
||||
|
||||
|
||||
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
|
||||
"""Parses a CSS-style margin string."""
|
||||
parts = [int(p) for p in margin_str.strip().split()]
|
||||
if len(parts) == 1:
|
||||
return parts[0], parts[0], parts[0], parts[0]
|
||||
if len(parts) == 2:
|
||||
return parts[0], parts[1], parts[0], parts[1]
|
||||
if len(parts) == 3:
|
||||
return parts[0], parts[1], parts[2], parts[1]
|
||||
if len(parts) == 4:
|
||||
return parts[0], parts[1], parts[2], parts[3]
|
||||
raise ValueError("Invalid margin format. Use 1 to 4 integer values.")
|
||||
if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0]
|
||||
if len(parts) == 2: return parts[0], parts[1], parts[0], parts[1]
|
||||
if len(parts) == 3: return parts[0], parts[1], parts[2], parts[1]
|
||||
if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3]
|
||||
raise ValueError("Invalid margin format.")
|
||||
|
||||
|
||||
class CropAndFitInpaintToDiffusionSize:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"mask": ("MASK",),
|
||||
"resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
|
||||
"margin": ("STRING", {"default": "64"}),
|
||||
"overflow": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
return {"required": {"image": ("IMAGE",), "mask": ("MASK",), "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}), "margin": ("STRING", {"default": "64"}), "overflow": ("BOOLEAN", {"default": True}), }}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "COMBO[INT]")
|
||||
RETURN_NAMES = ("image", "mask", "composite_context")
|
||||
FUNCTION = "crop_and_fit"
|
||||
CATEGORY = "inpaint"
|
||||
RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = ("IMAGE", "MASK", "COMBO[INT]"), ("image", "mask", "composite_context"), "crop_and_fit", "inpaint"
|
||||
|
||||
def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool):
|
||||
# 1. Find bounding box of the mask
|
||||
if mask.max() <= 0:
|
||||
raise ValueError("Mask is empty, cannot determine bounding box.")
|
||||
|
||||
# Find the coordinates of non-zero mask pixels
|
||||
mask_coords = torch.nonzero(mask[0]) # Assuming single batch for mask
|
||||
if mask_coords.numel() == 0:
|
||||
raise ValueError("Mask is empty, cannot determine bounding box.")
|
||||
|
||||
y_min, x_min = mask_coords.min(dim=0).values
|
||||
def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05):
|
||||
if mask.max() <= 0: raise ValueError("Mask is empty.")
|
||||
mask_coords = torch.nonzero(mask[0]);
|
||||
if mask_coords.numel() == 0: raise ValueError("Mask is empty.")
|
||||
y_min, x_min = mask_coords.min(dim=0).values;
|
||||
y_max, x_max = mask_coords.max(dim=0).values
|
||||
top_m, right_m, bottom_m, left_m = parse_margin(margin)
|
||||
x_start_init, y_start_init = x_min.item() - left_m, y_min.item() - top_m
|
||||
x_end_init, y_end_init = x_max.item() + 1 + right_m, y_max.item() + 1 + bottom_m
|
||||
img_h, img_w = image.shape[1:3]
|
||||
pad_image, pad_mask = image, mask
|
||||
x_start_crop, y_start_crop = x_start_init, y_start_init
|
||||
x_end_crop, y_end_crop = x_end_init, y_end_init
|
||||
pad_l, pad_t = -min(0, x_start_init), -min(0, y_start_init)
|
||||
pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h)
|
||||
if any([pad_l, pad_t, pad_r, pad_b]) and overflow:
|
||||
padding = (pad_l, pad_r, pad_t, pad_b)
|
||||
pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
|
||||
pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
|
||||
x_start_crop += pad_l;
|
||||
y_start_crop += pad_t;
|
||||
x_end_crop += pad_l;
|
||||
y_end_crop += pad_t
|
||||
else:
|
||||
x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init)
|
||||
x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init)
|
||||
composite_x, composite_y = (x_start_init if overflow else x_start_crop), (y_start_init if overflow else y_start_crop)
|
||||
cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :]
|
||||
cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop]
|
||||
context = {"x": composite_x, "y": composite_y, "width": cropped_image.shape[2], "height": cropped_image.shape[1]}
|
||||
|
||||
# 2. Parse and apply margin
|
||||
top_margin, right_margin, bottom_margin, left_margin = parse_margin(margin)
|
||||
|
||||
x_start = x_min.item() - left_margin
|
||||
y_start = y_min.item() - top_margin
|
||||
x_end = x_max.item() + 1 + right_margin
|
||||
y_end = y_max.item() + 1 + bottom_margin
|
||||
|
||||
img_height, img_width = image.shape[1:3]
|
||||
|
||||
# Store pre-crop context for the compositor node
|
||||
context = {
|
||||
"x": x_start,
|
||||
"y": y_start,
|
||||
"width": x_end - x_start,
|
||||
"height": y_end - y_start
|
||||
}
|
||||
|
||||
# 3. Handle overflow
|
||||
padded_image = image
|
||||
padded_mask = mask
|
||||
|
||||
pad_left = -min(0, x_start)
|
||||
pad_top = -min(0, y_start)
|
||||
pad_right = max(0, x_end - img_width)
|
||||
pad_bottom = max(0, y_end - img_height)
|
||||
|
||||
if any([pad_left, pad_top, pad_right, pad_bottom]):
|
||||
if not overflow:
|
||||
# Crop margin to fit within the image
|
||||
x_start = max(0, x_start)
|
||||
y_start = max(0, y_start)
|
||||
x_end = min(img_width, x_end)
|
||||
y_end = min(img_height, y_end)
|
||||
else:
|
||||
# Extend image and mask
|
||||
padding = (pad_left, pad_right, pad_top, pad_bottom)
|
||||
# Pad image with gray
|
||||
padded_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
|
||||
# Pad mask with zeros
|
||||
padded_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
|
||||
|
||||
# Adjust coordinates for the new padded space
|
||||
x_start += pad_left
|
||||
y_start += pad_top
|
||||
x_end += pad_left
|
||||
y_end += pad_top
|
||||
|
||||
# 4. Crop image and mask
|
||||
cropped_image = padded_image[:, y_start:y_end, x_start:x_end, :]
|
||||
cropped_mask = padded_mask[:, y_start:y_end, x_start:x_end]
|
||||
|
||||
# 5. Resize to a supported resolution
|
||||
resizer = ImageResize()
|
||||
resized_image, = resizer.resize_image(cropped_image, "cover", resolutions, "lanczos")
|
||||
|
||||
# Resize mask similarly. Convert to image-like tensor for resizing.
|
||||
cropped_mask_as_image = cropped_mask.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
resized_mask_as_image, = resizer.resize_image(cropped_mask_as_image, "cover", resolutions, "lanczos")
|
||||
# Convert back to a mask (using the red channel)
|
||||
resized_mask = resized_mask_as_image[:, :, :, 0]
|
||||
|
||||
# Pack context into a list of ints for output
|
||||
# Format: [x, y, width, height]
|
||||
composite_context = (context["x"], context["y"], context["width"], context["height"])
|
||||
|
||||
return (resized_image, resized_mask, composite_context)
|
||||
rgba_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1)
|
||||
res_map = {"SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS}
|
||||
supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)
|
||||
h, w = cropped_image.shape[1:3]
|
||||
current_aspect_ratio = w / h
|
||||
diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
|
||||
min_diff = min(diffs, key=lambda x: x[0])[0]
|
||||
close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance]
|
||||
target_res = max(close_res, key=lambda r: r[0] * r[1])
|
||||
scale = max(target_res[0] / w, target_res[1] / h)
|
||||
new_w, new_h = int(w * scale), int(h * scale)
|
||||
upscaled_rgba = F.interpolate(rgba_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
|
||||
y1, x1 = (new_h - target_res[1]) // 2, (new_w - target_res[0]) // 2
|
||||
final_rgba_bchw = upscaled_rgba[:, :, y1:y1 + target_res[1], x1:x1 + target_res[0]]
|
||||
final_rgba_bhwc = final_rgba_bchw.permute(0, 2, 3, 1)
|
||||
resized_image = final_rgba_bhwc[..., :3]
|
||||
resized_mask = (final_rgba_bhwc[..., 3] > 0.5).float()
|
||||
return (resized_image, resized_mask, (context["x"], context["y"], context["width"], context["height"]))
|
||||
|
||||
|
||||
class CompositeCroppedAndFittedInpaintResult:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"source_image": ("IMAGE",),
|
||||
"inpainted_image": ("IMAGE",),
|
||||
"inpainted_mask": ("MASK",),
|
||||
"composite_context": ("COMBO[INT]",),
|
||||
}
|
||||
}
|
||||
return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMBO[INT]",), }}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "composite_result"
|
||||
CATEGORY = "inpaint"
|
||||
RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint"
|
||||
|
||||
def composite_result(self, source_image: torch.Tensor, inpainted_image: torch.Tensor, inpainted_mask: MaskBatch, composite_context: tuple[int, ...]):
|
||||
# Unpack context
|
||||
def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]):
|
||||
x, y, width, height = composite_context
|
||||
|
||||
# The inpainted image and mask are at a diffusion resolution. Resize them back to the original crop size.
|
||||
target_size = (height, width)
|
||||
|
||||
# Resize inpainted image
|
||||
inpainted_image_permuted = inpainted_image.movedim(-1, 1)
|
||||
resized_inpainted_image = F.interpolate(inpainted_image_permuted, size=target_size, mode="bilinear", align_corners=False)
|
||||
resized_inpainted_image = F.interpolate(inpainted_image.permute(0, 3, 1, 2), size=target_size, mode="bilinear", align_corners=False)
|
||||
|
||||
# Resize inpainted mask
|
||||
# Add channel dim: (B, H, W) -> (B, 1, H, W)
|
||||
inpainted_mask_unsqueezed = inpainted_mask.unsqueeze(1)
|
||||
resized_inpainted_mask = F.interpolate(inpainted_mask_unsqueezed, size=target_size, mode="bilinear", align_corners=False)
|
||||
# FIX: The logic for cropping the original mask was flawed.
|
||||
# This simpler approach directly crops the relevant section of the original source_mask.
|
||||
# It correctly handles negative coordinates from the overflow case.
|
||||
crop_x_start = max(0, x)
|
||||
crop_y_start = max(0, y)
|
||||
crop_x_end = min(source_image.shape[2], x + width)
|
||||
crop_y_end = min(source_image.shape[1], y + height)
|
||||
|
||||
# Prepare for compositing
|
||||
destination_image = source_image.clone().movedim(-1, 1)
|
||||
# The mask for compositing is a direct, high-resolution crop of the source mask.
|
||||
final_compositing_mask = source_mask[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end]
|
||||
|
||||
# Composite the resized inpainted image back onto the source image
|
||||
final_image_permuted = composite(
|
||||
destination=destination_image,
|
||||
source=resized_inpainted_image,
|
||||
x=x,
|
||||
y=y,
|
||||
mask=resized_inpainted_mask
|
||||
)
|
||||
destination_image = source_image.clone().permute(0, 3, 1, 2)
|
||||
|
||||
final_image = final_image_permuted.movedim(1, -1)
|
||||
return (final_image,)
|
||||
# We now pass our perfectly cropped high-res mask to the composite function.
|
||||
# Note that the `composite` function handles placing this at the correct sub-region.
|
||||
final_image_permuted = composite(destination=destination_image, source=resized_inpainted_image, x=x, y=y, mask=final_compositing_mask)
|
||||
|
||||
return (final_image_permuted.permute(0, 2, 3, 1),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
|
||||
"CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
|
||||
"CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
|
||||
}
|
||||
NODE_CLASS_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize, "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}
|
||||
|
||||
@ -100,8 +100,6 @@ def test_ideogram_edit(api_key, sample_image, model, use_style_ref, red_style_im
|
||||
mask = torch.zeros((1, 1024, 1024), dtype=torch.float32)
|
||||
# Create a black square in the middle to be repainted
|
||||
mask[:, 256:768, 256:768] = 1.0
|
||||
# Invert mask: black regions are edited
|
||||
mask = 1.0 - mask
|
||||
|
||||
image, = node.edit(
|
||||
images=sample_image, masks=mask,
|
||||
|
||||
@ -5,157 +5,66 @@ import numpy as np
|
||||
# Assuming the node definitions are in a file named 'inpaint_nodes.py'
|
||||
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult, parse_margin
|
||||
|
||||
|
||||
# Helper to create a circular mask
|
||||
def create_circle_mask(height, width, center_y, center_x, radius):
|
||||
"""Creates a boolean mask with a filled circle."""
|
||||
Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
|
||||
distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
|
||||
mask = (distance <= radius).float()
|
||||
return mask.unsqueeze(0) # Add batch dimension
|
||||
|
||||
distance = torch.sqrt((Y - center_y)**2 + (X - center_x)**2)
|
||||
return (distance <= radius).float().unsqueeze(0)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image() -> torch.Tensor:
|
||||
"""A 256x256 image with a vertical gradient."""
|
||||
gradient = torch.linspace(0, 1, 256).view(1, -1, 1, 1)
|
||||
image = gradient.expand(1, 256, 256, 3) # (B, H, W, C)
|
||||
return image
|
||||
|
||||
return gradient.expand(1, 256, 256, 3)
|
||||
|
||||
@pytest.fixture
|
||||
def rect_mask() -> torch.Tensor:
|
||||
"""A rectangular mask in the center of a 256x256 image."""
|
||||
mask = torch.zeros(1, 256, 256)
|
||||
mask[:, 100:150, 80:180] = 1.0
|
||||
return mask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def circle_mask() -> torch.Tensor:
|
||||
"""A circular mask in a 256x256 image."""
|
||||
return create_circle_mask(256, 256, center_y=128, center_x=128, radius=50)
|
||||
|
||||
|
||||
def test_parse_margin():
|
||||
"""Tests the margin parsing utility function."""
|
||||
assert parse_margin("10") == (10, 10, 10, 10)
|
||||
assert parse_margin(" 10 20 ") == (10, 20, 10, 20)
|
||||
assert parse_margin("10 20 30") == (10, 20, 30, 20)
|
||||
assert parse_margin("10 20 30 40") == (10, 20, 30, 40)
|
||||
with pytest.raises(ValueError):
|
||||
parse_margin("10 20 30 40 50")
|
||||
with pytest.raises(ValueError):
|
||||
parse_margin("not a number")
|
||||
|
||||
|
||||
def test_crop_and_fit_basic(sample_image, rect_mask):
|
||||
"""Tests the basic functionality of the cropping and fitting node."""
|
||||
node = CropAndFitInpaintToDiffusionSize()
|
||||
|
||||
# Using SD1.5 resolutions for predictability in tests
|
||||
img, msk, ctx = node.crop_and_fit(sample_image, rect_mask, resolutions="SD1.5", margin="20", overflow=False)
|
||||
|
||||
# Check output shapes
|
||||
assert img.shape[0] == 1 and img.shape[3] == 3
|
||||
assert msk.shape[0] == 1
|
||||
# Check if resized to a valid SD1.5 resolution
|
||||
assert (img.shape[2], img.shape[1]) in [(512, 512), (768, 512), (512, 768)]
|
||||
assert img.shape[1:3] == msk.shape[1:3]
|
||||
|
||||
# Check context
|
||||
# Original mask bounds: y(100, 149), x(80, 179)
|
||||
# With margin 20: y(80, 169), x(60, 199)
|
||||
# context is (x, y, width, height)
|
||||
expected_x = 80 - 20
|
||||
expected_y = 100 - 20
|
||||
expected_width = (180 - 80) + 2 * 20
|
||||
expected_height = (150 - 100) + 2 * 20
|
||||
|
||||
assert ctx == (expected_x, expected_y, expected_width, expected_height)
|
||||
|
||||
|
||||
def test_crop_and_fit_overflow(sample_image, rect_mask):
|
||||
"""Tests the overflow logic by placing the mask at an edge."""
|
||||
node = CropAndFitInpaintToDiffusionSize()
|
||||
edge_mask = torch.zeros_like(rect_mask)
|
||||
edge_mask[:, :20, :50] = 1.0 # Mask at the top-left corner
|
||||
|
||||
# Test with overflow disabled (should clamp)
|
||||
edge_mask[:, :20, :50] = 1.0
|
||||
_, _, ctx_no_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=False)
|
||||
assert ctx_no_overflow == (0, 0, 50 + 30, 20 + 30)
|
||||
|
||||
# Test with overflow enabled
|
||||
img, msk, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
|
||||
# Context should have negative coordinates
|
||||
# Original bounds: y(0, 19), x(0, 49)
|
||||
# Margin 30: y(-30, 49), x(-30, 79)
|
||||
assert ctx_overflow == (-30, -30, (50 - 0) + 60, (20 - 0) + 60)
|
||||
|
||||
# Check that padded area is gray
|
||||
# The original image was placed inside a larger gray canvas.
|
||||
# We check a pixel that should be in the padded gray area of the *cropped* image.
|
||||
# The crop starts at y=-30, x=-30 relative to original image.
|
||||
# So, pixel (5,5) in the cropped image corresponds to (-25, -25) which is padding.
|
||||
assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]))
|
||||
|
||||
# Check that original image content is still there
|
||||
# Pixel (40, 40) in cropped image corresponds to (10, 10) in original image
|
||||
assert torch.allclose(img[0, 40, 40, :], sample_image[0, 10, 10, :])
|
||||
|
||||
|
||||
def test_empty_mask_raises_error(sample_image):
|
||||
"""Tests that an empty mask correctly raises a ValueError."""
|
||||
node = CropAndFitInpaintToDiffusionSize()
|
||||
empty_mask = torch.zeros(1, 256, 256)
|
||||
with pytest.raises(ValueError, match="Mask is empty"):
|
||||
node.crop_and_fit(sample_image, empty_mask, "SD1.5", "10", False)
|
||||
|
||||
assert ctx_no_overflow == (0, 0, 80, 50)
|
||||
img, _, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
|
||||
assert ctx_overflow == (-30, -30, 110, 80)
|
||||
assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]), atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("mask_fixture, margin, overflow", [
|
||||
("rect_mask", "16", False),
|
||||
("circle_mask", "32", False),
|
||||
("rect_mask", "64", True), # margin forces overflow
|
||||
("rect_mask", "64", True),
|
||||
("circle_mask", "0", False),
|
||||
])
|
||||
def test_end_to_end_composition(request, sample_image, mask_fixture, margin, overflow):
|
||||
"""Performs a full round-trip test of both nodes."""
|
||||
mask = request.getfixturevalue(mask_fixture)
|
||||
|
||||
# --- 1. Crop and Fit ---
|
||||
crop_node = CropAndFitInpaintToDiffusionSize()
|
||||
cropped_img, cropped_mask, context = crop_node.crop_and_fit(
|
||||
sample_image, mask, "SD1.5", margin, overflow
|
||||
)
|
||||
composite_node = CompositeCroppedAndFittedInpaintResult()
|
||||
|
||||
# The resized mask from the first node is not needed for compositing.
|
||||
cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin, overflow)
|
||||
|
||||
# --- 2. Simulate Inpainting ---
|
||||
# Create a solid blue image as the "inpainted" result
|
||||
h, w = cropped_img.shape[1:3]
|
||||
blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
|
||||
inpainted_sim = blue_color.expand(1, h, w, 3)
|
||||
# The inpainted_mask is the mask output from the first node
|
||||
inpainted_mask = cropped_mask
|
||||
|
||||
# --- 3. Composite Result ---
|
||||
composite_node = CompositeCroppedAndFittedInpaintResult()
|
||||
# FIX: Pass the original, high-resolution mask as `source_mask`.
|
||||
final_image, = composite_node.composite_result(
|
||||
source_image=sample_image,
|
||||
source_mask=mask,
|
||||
inpainted_image=inpainted_sim,
|
||||
inpainted_mask=inpainted_mask,
|
||||
composite_context=context
|
||||
)
|
||||
|
||||
# --- 4. Verify Result ---
|
||||
assert final_image.shape == sample_image.shape
|
||||
|
||||
# Create a boolean version of the original mask for easy indexing
|
||||
bool_mask = mask.squeeze(0).bool() # H, W
|
||||
|
||||
# Area *inside* the mask should be blue
|
||||
masked_area_in_final = final_image[0][bool_mask]
|
||||
assert torch.allclose(masked_area_in_final, blue_color.squeeze(), atol=1e-2)
|
||||
|
||||
# Area *outside* the mask should be unchanged from the original
|
||||
unmasked_area_in_final = final_image[0][~bool_mask]
|
||||
unmasked_area_in_original = sample_image[0][~bool_mask]
|
||||
assert torch.allclose(unmasked_area_in_final, unmasked_area_in_original, atol=1e-2)
|
||||
bool_mask = mask.squeeze(0).bool()
|
||||
assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
|
||||
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
|
||||
Loading…
Reference in New Issue
Block a user