Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 14:50:49 +08:00)

Commit d4c9d5c748: inpainting nodes
Parent 396a2ef3d3
@@ -10,7 +10,7 @@ from PIL.PngImagePlugin import PngInfo
 from comfy import utils
 from comfy.cli_args import args
 from comfy.cmd import folder_paths
-from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
+from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch, RGBAImageBatch
 from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
@@ -286,7 +286,7 @@ class ImageResize:
     FUNCTION = "resize_image"
     CATEGORY = "image/transform"

-    def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
+    def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
         if resolutions == "SDXL/SD3/Flux":
             supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS
         elif resolutions == "LTXV":
@@ -307,7 +307,7 @@ class ImageResize:
             supported_resolutions = SD_RESOLUTIONS
         return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation, aspect_ratio_tolerance=aspect_ratio_tolerance)

-    def resize_image_with_supported_resolutions(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
+    def resize_image_with_supported_resolutions(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
         resized_images = []
         for img in image:
             h, w = img.shape[:2]

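Note: the two ImageResize hunks above only widen the accepted input type from RGBImageBatch to the more general ImageBatch, so 4-channel (RGBA) batches can flow through the same cover/contain fitting path. For context, here is a small standalone sketch of "cover" style fitting (scale so the target is fully covered, then center-crop); it illustrates the technique only and is not the node's actual implementation:

import torch
import torch.nn.functional as F

def cover_resize(img_bhwc: torch.Tensor, target_w: int, target_h: int) -> torch.Tensor:
    """Scale so the image covers (target_h, target_w), then center-crop. Illustrative sketch only."""
    b, h, w, c = img_bhwc.shape
    scale = max(target_w / w, target_h / h)  # "cover": use the larger of the two ratios
    new_w, new_h = int(round(w * scale)), int(round(h * scale))
    x = F.interpolate(img_bhwc.permute(0, 3, 1, 2), size=(new_h, new_w),
                      mode="bilinear", align_corners=False)
    y0 = (new_h - target_h) // 2
    x0 = (new_w - target_w) // 2
    return x[:, :, y0:y0 + target_h, x0:x0 + target_w].permute(0, 2, 3, 1)

print(cover_resize(torch.rand(1, 300, 400, 3), 512, 512).shape)  # torch.Size([1, 512, 512, 3])
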
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
+from typing import NamedTuple, Optional

-from comfy.component_model.tensor_types import MaskBatch
+from comfy.component_model.tensor_types import MaskBatch, ImageBatch
 from comfy.nodes.package_typing import CustomNode
 from comfy_extras.constants.resolutions import (
     RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
     IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
@@ -9,11 +11,15 @@ from comfy_extras.constants.resolutions import (
     WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
 )

+class CompositeContext(NamedTuple):
+    x: int
+    y: int
+    width: int
+    height: int

-def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
+def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None):
     """A robust function to composite a source tensor onto a destination tensor."""
     source = source.to(destination.device)
-    if resize_source:
-        source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
     if source.shape[0] != destination.shape[0]:
         source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
@@ -40,13 +46,14 @@ def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=

     destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
     source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
-    mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]
+    # The mask must be cropped to the region of interest on the destination.
+    mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right]

     blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
     destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
     return destination


 def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
     parts = [int(p) for p in margin_str.strip().split()]
     if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0]
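The blend in the hunk above is a plain alpha composite over the overlapping window: out = source * mask + destination * (1 - mask), with the mask now cropped in destination coordinates. A tiny self-contained check of that formula on toy tensors (independent of the composite() helper):

import torch

# 1x1x4x4 destination (zeros) and source (ones), mask covering the left half.
destination = torch.zeros(1, 1, 4, 4)
source = torch.ones(1, 1, 4, 4)
mask = torch.zeros(1, 1, 4, 4)
mask[..., :2] = 1.0  # left two columns come from the source

blended = source * mask + destination * (1.0 - mask)
assert torch.equal(blended[..., :2], torch.ones(1, 1, 4, 2))   # masked area takes the source
assert torch.equal(blended[..., 2:], torch.zeros(1, 1, 4, 2))  # unmasked area keeps the destination
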
@@ -55,97 +62,93 @@ def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
     if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3]
     raise ValueError("Invalid margin format.")


-class CropAndFitInpaintToDiffusionSize:
+class CropAndFitInpaintToDiffusionSize(CustomNode):
     @classmethod
     def INPUT_TYPES(cls):
-        return {"required": {"image": ("IMAGE",), "mask": ("MASK",), "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}), "margin": ("STRING", {"default": "64"}), "overflow": ("BOOLEAN", {"default": True}), }}
+        return {"required": {
+            "image": ("IMAGE",), "mask": ("MASK",),
+            "resolutions": (RESOLUTION_NAMES, {"default": "SD1.5"}),
+            "margin": ("STRING", {"default": "64"}),
+        }}

-    RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = ("IMAGE", "MASK", "COMBO[INT]"), ("image", "mask", "composite_context"), "crop_and_fit", "inpaint"
+    RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
+    RETURN_NAMES = ("image", "mask", "composite_context")
+    FUNCTION = "crop_and_fit"
+    CATEGORY = "inpaint"

-    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05):
+    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, aspect_ratio_tolerance=0.05):
-        if mask.max() <= 0: raise ValueError("Mask is empty.")
-        mask_coords = torch.nonzero(mask[0]);
+        mask_coords = torch.nonzero(mask)
+        if mask_coords.numel() == 0: raise ValueError("Mask is empty.")
-        y_min, x_min = mask_coords.min(dim=0).values;
-        y_max, x_max = mask_coords.max(dim=0).values

+        y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
+        y_min, x_min = y_coords.min().item(), x_coords.min().item()
+        y_max, x_max = y_coords.max().item(), x_coords.max().item()

         top_m, right_m, bottom_m, left_m = parse_margin(margin)
-        x_start_init, y_start_init = x_min.item() - left_m, y_min.item() - top_m
-        x_end_init, y_end_init = x_max.item() + 1 + right_m, y_max.item() + 1 + bottom_m
+        x_start_expanded, y_start_expanded = x_min - left_m, y_min - top_m
+        x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m

         img_h, img_w = image.shape[1:3]
-        pad_image, pad_mask = image, mask
-        x_start_crop, y_start_crop = x_start_init, y_start_init
-        x_end_crop, y_end_crop = x_end_init, y_end_init
-        pad_l, pad_t = -min(0, x_start_init), -min(0, y_start_init)
-        pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h)
-        if any([pad_l, pad_t, pad_r, pad_b]) and overflow:
-            padding = (pad_l, pad_r, pad_t, pad_b)
-            pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
-            pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
-            x_start_crop += pad_l;
-            y_start_crop += pad_t;
-            x_end_crop += pad_l;
-            y_end_crop += pad_t
-        else:
-            x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init)
-            x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init)
-        composite_x, composite_y = (x_start_init if overflow else x_start_crop), (y_start_init if overflow else y_start_crop)
-        cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :]
-        cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop]
-        context = {"x": composite_x, "y": composite_y, "width": cropped_image.shape[2], "height": cropped_image.shape[1]}
+        clamped_x_start, clamped_y_start = max(0, x_start_expanded), max(0, y_start_expanded)
+        clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)

-        rgba_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1)
-        res_map = {"SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS}
+        initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
+        if initial_w <= 0 or initial_h <= 0: raise ValueError("Cropped area has zero dimension.")

+        res_map = { "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS }
         supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)
-        h, w = cropped_image.shape[1:3]
-        current_aspect_ratio = w / h
-        diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
-        min_diff = min(diffs, key=lambda x: x[0])[0]
-        close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance]
-        target_res = max(close_res, key=lambda r: r[0] * r[1])
-        scale = max(target_res[0] / w, target_res[1] / h)
-        new_w, new_h = int(w * scale), int(h * scale)
-        upscaled_rgba = F.interpolate(rgba_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
-        y1, x1 = (new_h - target_res[1]) // 2, (new_w - target_res[0]) // 2
-        final_rgba_bchw = upscaled_rgba[:, :, y1:y1 + target_res[1], x1:x1 + target_res[0]]
-        final_rgba_bhwc = final_rgba_bchw.permute(0, 2, 3, 1)
-        resized_image = final_rgba_bhwc[..., :3]
-        resized_mask = (final_rgba_bhwc[..., 3] > 0.5).float()
-        return (resized_image, resized_mask, (context["x"], context["y"], context["width"], context["height"]))
+        diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
+        target_res = min(diffs, key=lambda x: x[0])[1]
+        target_ar = target_res[0] / target_res[1]

+        current_ar = initial_w / initial_h
+        final_x, final_y = float(clamped_x_start), float(clamped_y_start)
+        final_w, final_h = float(initial_w), float(initial_h)

-class CompositeCroppedAndFittedInpaintResult:
+        if current_ar > target_ar:
+            final_w = initial_h * target_ar
+            final_x += (initial_w - final_w) / 2
+        else:
+            final_h = initial_w / target_ar
+            final_y += (initial_h - final_h) / 2

+        final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)

+        cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
+        cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]

+        resized_image = F.interpolate(cropped_image.permute(0,3,1,2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0,2,3,1)
+        resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1)

+        composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
+        return (resized_image, resized_mask, composite_context)


+class CompositeCroppedAndFittedInpaintResult(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMBO[INT]",), }}
+        return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMPOSITE_CONTEXT",),}}

     RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint"

-    def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]):
-        x, y, width, height = composite_context
-        target_size = (height, width)
+    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext):
+        context_x, context_y, context_w, context_h = composite_context

-        resized_inpainted_image = F.interpolate(inpainted_image.permute(0, 3, 1, 2), size=target_size, mode="bilinear", align_corners=False)
+        resized_inpainted = F.interpolate(
+            inpainted_image.permute(0, 3, 1, 2),
+            size=(context_h, context_w),
+            mode="bilinear", align_corners=False
+        )

-        # FIX: The logic for cropping the original mask was flawed.
-        # This simpler approach directly crops the relevant section of the original source_mask.
-        # It correctly handles negative coordinates from the overflow case.
-        crop_x_start = max(0, x)
-        crop_y_start = max(0, y)
-        crop_x_end = min(source_image.shape[2], x + width)
-        crop_y_end = min(source_image.shape[1], y + height)
-
-        # The mask for compositing is a direct, high-resolution crop of the source mask.
-        final_compositing_mask = source_mask[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end]
-
-        destination_image = source_image.clone().permute(0, 3, 1, 2)
-
-        # We now pass our perfectly cropped high-res mask to the composite function.
-        # Note that the `composite` function handles placing this at the correct sub-region.
-        final_image_permuted = composite(destination=destination_image, source=resized_inpainted_image, x=x, y=y, mask=final_compositing_mask)
-
-        return (final_image_permuted.permute(0, 2, 3, 1),)
+        final_image = composite(
+            destination=source_image.clone().permute(0, 3, 1, 2),
+            source=resized_inpainted,
+            x=context_x,
+            y=context_y,
+            mask=source_mask
+        )

+        return (final_image.permute(0, 2, 3, 1),)


 NODE_CLASS_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize, "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult}
 NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}

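Taken together, the two nodes are designed to round-trip: crop and fit the masked region to a supported diffusion resolution, inpaint the crop, then composite the result back into the original image. A minimal sketch of that flow outside a ComfyUI graph, mirroring the tests below (the module path comes from the test imports; the solid-color tensor stands in for a real sampler's inpainted output):

import torch

from comfy_extras.nodes.nodes_inpainting import (
    CropAndFitInpaintToDiffusionSize,
    CompositeCroppedAndFittedInpaintResult,
)

image = torch.rand(1, 256, 256, 3)   # BHWC source image
mask = torch.zeros(1, 256, 256)      # BHW mask of the region to repaint
mask[:, 100:150, 80:180] = 1.0

crop_node = CropAndFitInpaintToDiffusionSize()
cropped, cropped_mask, context = crop_node.crop_and_fit(image, mask, "SD1.5", "64")

# Stand-in for a diffusion inpainting pass over `cropped` / `cropped_mask`.
inpainted = torch.full_like(cropped, 0.5)

composite_node = CompositeCroppedAndFittedInpaintResult()
(result,) = composite_node.composite_result(
    source_image=image,
    source_mask=mask,
    inpainted_image=inpainted,
    composite_context=context,
)
assert result.shape == image.shape
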
@@ -1,61 +1,67 @@
 import pytest
 import torch
 import numpy as np

 # Assuming the node definitions are in a file named 'inpaint_nodes.py'
-from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult, parse_margin
+from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult


 def create_circle_mask(height, width, center_y, center_x, radius):
     Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
-    distance = torch.sqrt((Y - center_y)**2 + (X - center_x)**2)
-    return (distance <= radius).float().unsqueeze(0)
+    distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
+    return (distance >= radius).float().unsqueeze(0)


 @pytest.fixture
 def sample_image() -> torch.Tensor:
     gradient = torch.linspace(0, 1, 256).view(1, -1, 1, 1)
     return gradient.expand(1, 256, 256, 3)


+@pytest.fixture
+def image_1024() -> torch.Tensor:
+    gradient = torch.linspace(0, 1, 1024).view(1, -1, 1, 1)
+    return gradient.expand(1, 1024, 1024, 3)


 @pytest.fixture
 def rect_mask() -> torch.Tensor:
-    mask = torch.zeros(1, 256, 256)
-    mask[:, 100:150, 80:180] = 1.0
+    mask = torch.ones(1, 256, 256)
+    mask[:, 100:150, 80:180] = 0.0
     return mask


 @pytest.fixture
 def circle_mask() -> torch.Tensor:
     return create_circle_mask(256, 256, center_y=128, center_x=128, radius=50)

-def test_crop_and_fit_overflow(sample_image, rect_mask):
-    """Tests the overflow logic by placing the mask at an edge."""
-    node = CropAndFitInpaintToDiffusionSize()
-    edge_mask = torch.zeros_like(rect_mask)
-    edge_mask[:, :20, :50] = 1.0
-    _, _, ctx_no_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=False)
-    assert ctx_no_overflow == (0, 0, 80, 50)
-    img, _, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
-    assert ctx_overflow == (-30, -30, 110, 80)
-    assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]), atol=1e-3)

-@pytest.mark.parametrize("mask_fixture, margin, overflow", [
-    ("rect_mask", "16", False),
-    ("circle_mask", "32", False),
-    ("rect_mask", "64", True),
-    ("circle_mask", "0", False),
+def test_crop_and_fit_edge_clamp(sample_image):
+    node = CropAndFitInpaintToDiffusionSize()
+    edge_mask = torch.zeros(1, 256, 256)
+    edge_mask[:, :20, :50] = 1.0
+
+    _, _, context = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30")
+
+    target_aspect_ratio = 1.0  # For SD1.5, the only valid resolution is 512x512
+    actual_aspect_ratio = context.width / context.height
+    assert abs(actual_aspect_ratio - target_aspect_ratio) < 1e-4


+@pytest.mark.parametrize("mask_fixture, margin", [
+    ("rect_mask", "16"),
+    ("circle_mask", "32"),
+    ("circle_mask", "0"),
 ])
-def test_end_to_end_composition(request, sample_image, mask_fixture, margin, overflow):
-    """Performs a full round-trip test of both nodes."""
+def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
     mask = request.getfixturevalue(mask_fixture)
     crop_node = CropAndFitInpaintToDiffusionSize()
     composite_node = CompositeCroppedAndFittedInpaintResult()

     # The resized mask from the first node is not needed for compositing.
-    cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin, overflow)
+    cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)

     h, w = cropped_img.shape[1:3]
     blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
     inpainted_sim = blue_color.expand(1, h, w, 3)

     # FIX: Pass the original, high-resolution mask as `source_mask`.
     final_image, = composite_node.composite_result(
         source_image=sample_image,
         source_mask=mask,
@@ -66,5 +72,41 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin, ove
     assert final_image.shape == sample_image.shape

     bool_mask = mask.squeeze(0).bool()

     assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
-    assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
+    assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask])


+def test_wide_ideogram_composite(image_1024):
+    """Tests the wide margin scenario. The node logic correctly chooses 1536x512."""
+    source_image = image_1024
+    mask = torch.zeros(1, 1024, 1024)
+    mask[:, 900:932, 950:982] = 1.0
+
+    crop_node = CropAndFitInpaintToDiffusionSize()
+    composite_node = CompositeCroppedAndFittedInpaintResult()
+
+    margin = "64 64 64 400"
+
+    cropped_img, _, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
+    assert cropped_img.shape[1:3] == (512, 1536)
+
+    green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
+    inpainted_sim = green_color.expand(1, 512, 1536, 3)
+
+    final_image, = composite_node.composite_result(
+        source_image=source_image,
+        source_mask=mask,
+        inpainted_image=inpainted_sim,
+        composite_context=context
+    )
+
+    assert final_image.shape == source_image.shape
+
+    bool_mask = mask.squeeze(0).bool()
+
+    final_pixels = final_image[0][bool_mask]
+    assert torch.all(final_pixels[:, 1] > final_pixels[:, 0])
+    assert torch.all(final_pixels[:, 1] > final_pixels[:, 2])
+
+    assert torch.allclose(final_image[0, 916, 940, :], source_image[0, 916, 940, :])
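A side note on the updated test helper: with the comparison flipped to distance >= radius, create_circle_mask now selects everything outside the circle, so the circle_mask fixture covers most of the frame and leaves only a 50-pixel-radius hole untouched. A quick standalone check of that inverted behaviour, using the same math as the helper above (floating-point coordinates used here to sidestep any integer-sqrt promotion concerns):

import torch

def circle_mask_outside(height, width, center_y, center_x, radius):
    # Same formula as the post-change helper, written with float coordinates.
    Y, X = torch.meshgrid(torch.arange(height, dtype=torch.float32),
                          torch.arange(width, dtype=torch.float32), indexing="ij")
    distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
    return (distance >= radius).float().unsqueeze(0)

mask = circle_mask_outside(256, 256, center_y=128, center_x=128, radius=50)
assert mask[0, 128, 128] == 0.0  # centre of the circle is left unmasked
assert mask[0, 0, 0] == 1.0      # far corner (distance ~181) is masked
assert mask.mean() > 0.5         # most of the 256x256 frame is covered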