inpainting nodes

Benjamin Berman 2025-06-07 10:19:05 -07:00
parent 396a2ef3d3
commit d4c9d5c748
3 changed files with 154 additions and 109 deletions

View File

@@ -10,7 +10,7 @@ from PIL.PngImagePlugin import PngInfo
from comfy import utils
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch, RGBAImageBatch
from comfy.nodes.base_nodes import ImageScale
from comfy.nodes.common import MAX_RESOLUTION
from comfy.nodes.package_typing import CustomNode
@@ -286,7 +286,7 @@ class ImageResize:
FUNCTION = "resize_image"
CATEGORY = "image/transform"
def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
if resolutions == "SDXL/SD3/Flux":
supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS
elif resolutions == "LTXV":
@@ -307,7 +307,7 @@ class ImageResize:
supported_resolutions = SD_RESOLUTIONS
return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation, aspect_ratio_tolerance=aspect_ratio_tolerance)
def resize_image_with_supported_resolutions(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
def resize_image_with_supported_resolutions(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
resized_images = []
for img in image:
h, w = img.shape[:2]
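To illustrate the widened input annotation, a minimal sketch of calling the node with a four-channel batch; the import path, the "bilinear" interpolation value, and the tensor contents are assumptions for illustration only, while the signature itself comes from the diff above.

import torch
from comfy_extras.nodes.nodes_images import ImageResize  # assumed path; the diff does not show the file name

# Hypothetical BHWC RGBA batch, now admitted by the ImageBatch input annotation.
rgba = torch.rand(1, 640, 480, 4)
node = ImageResize()
# The return annotation is unchanged, so the output is still typed as an RGB batch.
(resized,) = node.resize_image(rgba, "cover", "SDXL/SD3/Flux", "bilinear")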

View File

@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
from typing import NamedTuple, Optional
from comfy.component_model.tensor_types import MaskBatch
from comfy.component_model.tensor_types import MaskBatch, ImageBatch
from comfy.nodes.package_typing import CustomNode
from comfy_extras.constants.resolutions import (
RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
@@ -9,11 +11,15 @@ from comfy_extras.constants.resolutions import (
WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
)
class CompositeContext(NamedTuple):
x: int
y: int
width: int
height: int
def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None):
"""Composite a source tensor onto a destination tensor at (x, y), optionally blending with a mask."""
source = source.to(destination.device)
if resize_source:
source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
if source.shape[0] != destination.shape[0]:
source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
@@ -40,13 +46,14 @@ def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=
destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]
# The mask must be cropped to the region of interest on the destination.
mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right]
blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
return destination
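For reference, a minimal sketch of a call under the new signature, assuming the BCHW layout and the destination-sized, single-channel mask that the indexing above implies; the shapes are illustrative only.

import torch
from comfy_extras.nodes.nodes_inpainting import composite

destination = torch.zeros(1, 3, 256, 256)   # BCHW canvas
source = torch.ones(1, 3, 64, 64)           # BCHW patch to paste at (x, y)
mask = torch.zeros(1, 1, 256, 256)          # mask indexed in destination coordinates
mask[:, :, 32:96, 32:96] = 1.0
out = composite(destination, source, x=32, y=32, mask=mask)
# Where the mask is 1 inside the pasted window, out should equal the source; elsewhere it keeps the destination.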
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
parts = [int(p) for p in margin_str.strip().split()]
if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0]
@@ -55,97 +62,93 @@ def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3]
raise ValueError("Invalid margin format.")
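The margin string follows the CSS shorthand order that the caller unpacks as top, right, bottom, left; two illustrative values covering the one- and four-part forms shown above:

parse_margin("64")            # (64, 64, 64, 64)
parse_margin("64 64 64 400")  # (64, 64, 64, 400): left margin of 400, the others 64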
class CropAndFitInpaintToDiffusionSize:
class CropAndFitInpaintToDiffusionSize(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {"required": {"image": ("IMAGE",), "mask": ("MASK",), "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}), "margin": ("STRING", {"default": "64"}), "overflow": ("BOOLEAN", {"default": True}), }}
return {"required": {
"image": ("IMAGE",), "mask": ("MASK",),
"resolutions": (RESOLUTION_NAMES, {"default": "SD1.5"}),
"margin": ("STRING", {"default": "64"}),
}}
RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = ("IMAGE", "MASK", "COMBO[INT]"), ("image", "mask", "composite_context"), "crop_and_fit", "inpaint"
RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
RETURN_NAMES = ("image", "mask", "composite_context")
FUNCTION = "crop_and_fit"
CATEGORY = "inpaint"
def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05):
def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, aspect_ratio_tolerance=0.05):
if mask.max() <= 0: raise ValueError("Mask is empty.")
mask_coords = torch.nonzero(mask[0]);
mask_coords = torch.nonzero(mask)
if mask_coords.numel() == 0: raise ValueError("Mask is empty.")
y_min, x_min = mask_coords.min(dim=0).values;
y_max, x_max = mask_coords.max(dim=0).values
y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
y_min, x_min = y_coords.min().item(), x_coords.min().item()
y_max, x_max = y_coords.max().item(), x_coords.max().item()
top_m, right_m, bottom_m, left_m = parse_margin(margin)
x_start_init, y_start_init = x_min.item() - left_m, y_min.item() - top_m
x_end_init, y_end_init = x_max.item() + 1 + right_m, y_max.item() + 1 + bottom_m
x_start_expanded, y_start_expanded = x_min - left_m, y_min - top_m
x_end_expanded, y_end_expanded = x_max + 1 + right_m, y_max + 1 + bottom_m
img_h, img_w = image.shape[1:3]
pad_image, pad_mask = image, mask
x_start_crop, y_start_crop = x_start_init, y_start_init
x_end_crop, y_end_crop = x_end_init, y_end_init
pad_l, pad_t = -min(0, x_start_init), -min(0, y_start_init)
pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h)
if any([pad_l, pad_t, pad_r, pad_b]) and overflow:
padding = (pad_l, pad_r, pad_t, pad_b)
pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
x_start_crop += pad_l;
y_start_crop += pad_t;
x_end_crop += pad_l;
y_end_crop += pad_t
else:
x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init)
x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init)
composite_x, composite_y = (x_start_init if overflow else x_start_crop), (y_start_init if overflow else y_start_crop)
cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :]
cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop]
context = {"x": composite_x, "y": composite_y, "width": cropped_image.shape[2], "height": cropped_image.shape[1]}
clamped_x_start, clamped_y_start = max(0, x_start_expanded), max(0, y_start_expanded)
clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)
rgba_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1)
res_map = {"SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS}
initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
if initial_w <= 0 or initial_h <= 0: raise ValueError("Cropped area has zero dimension.")
res_map = { "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS }
supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)
h, w = cropped_image.shape[1:3]
current_aspect_ratio = w / h
diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
min_diff = min(diffs, key=lambda x: x[0])[0]
close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance]
target_res = max(close_res, key=lambda r: r[0] * r[1])
scale = max(target_res[0] / w, target_res[1] / h)
new_w, new_h = int(w * scale), int(h * scale)
upscaled_rgba = F.interpolate(rgba_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
y1, x1 = (new_h - target_res[1]) // 2, (new_w - target_res[0]) // 2
final_rgba_bchw = upscaled_rgba[:, :, y1:y1 + target_res[1], x1:x1 + target_res[0]]
final_rgba_bhwc = final_rgba_bchw.permute(0, 2, 3, 1)
resized_image = final_rgba_bhwc[..., :3]
resized_mask = (final_rgba_bhwc[..., 3] > 0.5).float()
return (resized_image, resized_mask, (context["x"], context["y"], context["width"], context["height"]))
diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
target_res = min(diffs, key=lambda x: x[0])[1]
target_ar = target_res[0] / target_res[1]
current_ar = initial_w / initial_h
final_x, final_y = float(clamped_x_start), float(clamped_y_start)
final_w, final_h = float(initial_w), float(initial_h)
if current_ar > target_ar:
final_w = initial_h * target_ar
final_x += (initial_w - final_w) / 2
else:
final_h = initial_w / target_ar
final_y += (initial_h - final_h) / 2
final_x, final_y, final_w, final_h = int(final_x), int(final_y), int(final_w), int(final_h)
cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]
resized_image = F.interpolate(cropped_image.permute(0,3,1,2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0,2,3,1)
resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1)
composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
return (resized_image, resized_mask, composite_context)
class CompositeCroppedAndFittedInpaintResult:
class CompositeCroppedAndFittedInpaintResult(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMBO[INT]",), }}
return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMPOSITE_CONTEXT",),}}
RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint"
def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]):
x, y, width, height = composite_context
target_size = (height, width)
def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext):
context_x, context_y, context_w, context_h = composite_context
resized_inpainted_image = F.interpolate(inpainted_image.permute(0, 3, 1, 2), size=target_size, mode="bilinear", align_corners=False)
resized_inpainted = F.interpolate(
inpainted_image.permute(0, 3, 1, 2),
size=(context_h, context_w),
mode="bilinear", align_corners=False
)
# FIX: The logic for cropping the original mask was flawed.
# This simpler approach directly crops the relevant section of the original source_mask.
# It correctly handles negative coordinates from the overflow case.
crop_x_start = max(0, x)
crop_y_start = max(0, y)
crop_x_end = min(source_image.shape[2], x + width)
crop_y_end = min(source_image.shape[1], y + height)
# The mask for compositing is a direct, high-resolution crop of the source mask.
final_compositing_mask = source_mask[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end]
destination_image = source_image.clone().permute(0, 3, 1, 2)
# We now pass our perfectly cropped high-res mask to the composite function.
# Note that the `composite` function handles placing this at the correct sub-region.
final_image_permuted = composite(destination=destination_image, source=resized_inpainted_image, x=x, y=y, mask=final_compositing_mask)
return (final_image_permuted.permute(0, 2, 3, 1),)
final_image = composite(
destination=source_image.clone().permute(0, 3, 1, 2),
source=resized_inpainted,
x=context_x,
y=context_y,
mask=source_mask
)
return (final_image.permute(0, 2, 3, 1),)
NODE_CLASS_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize, "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult}
NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}

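Taken together, the two nodes are meant to bracket a diffusion inpaint: crop and fit the masked region first, run the model on the fitted crop, then composite the result back into the source. A rough sketch of that round trip outside a graph, mirroring the tests below; the inpaint step is simulated with a flat gray fill.

import torch
from comfy_extras.nodes.nodes_inpainting import (
    CropAndFitInpaintToDiffusionSize,
    CompositeCroppedAndFittedInpaintResult,
)

image = torch.rand(1, 256, 256, 3)      # BHWC source image
mask = torch.zeros(1, 256, 256)
mask[:, 100:150, 80:180] = 1.0          # region to inpaint

cropped, cropped_mask, context = CropAndFitInpaintToDiffusionSize().crop_and_fit(image, mask, "SD1.5", "64")

# A real workflow would run a diffusion inpaint on `cropped` / `cropped_mask` here.
inpainted = torch.ones_like(cropped) * 0.5

(result,) = CompositeCroppedAndFittedInpaintResult().composite_result(
    source_image=image,
    source_mask=mask,
    inpainted_image=inpainted,
    composite_context=context,
)
assert result.shape == image.shape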
View File

@@ -1,61 +1,67 @@
import pytest
import torch
import numpy as np
# Node definitions live in comfy_extras.nodes.nodes_inpainting.
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult, parse_margin
from comfy_extras.nodes.nodes_inpainting import CropAndFitInpaintToDiffusionSize, CompositeCroppedAndFittedInpaintResult
def create_circle_mask(height, width, center_y, center_x, radius):
Y, X = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
distance = torch.sqrt((Y - center_y)**2 + (X - center_x)**2)
return (distance <= radius).float().unsqueeze(0)
distance = torch.sqrt((Y - center_y) ** 2 + (X - center_x) ** 2)
return (distance >= radius).float().unsqueeze(0)
@pytest.fixture
def sample_image() -> torch.Tensor:
gradient = torch.linspace(0, 1, 256).view(1, -1, 1, 1)
return gradient.expand(1, 256, 256, 3)
@pytest.fixture
def image_1024() -> torch.Tensor:
gradient = torch.linspace(0, 1, 1024).view(1, -1, 1, 1)
return gradient.expand(1, 1024, 1024, 3)
@pytest.fixture
def rect_mask() -> torch.Tensor:
mask = torch.zeros(1, 256, 256)
mask[:, 100:150, 80:180] = 1.0
mask = torch.ones(1, 256, 256)
mask[:, 100:150, 80:180] = 0.0
return mask
@pytest.fixture
def circle_mask() -> torch.Tensor:
return create_circle_mask(256, 256, center_y=128, center_x=128, radius=50)
def test_crop_and_fit_overflow(sample_image, rect_mask):
"""Tests the overflow logic by placing the mask at an edge."""
node = CropAndFitInpaintToDiffusionSize()
edge_mask = torch.zeros_like(rect_mask)
edge_mask[:, :20, :50] = 1.0
_, _, ctx_no_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=False)
assert ctx_no_overflow == (0, 0, 80, 50)
img, _, ctx_overflow = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30", overflow=True)
assert ctx_overflow == (-30, -30, 110, 80)
assert torch.allclose(img[0, 5, 5, :], torch.tensor([0.5, 0.5, 0.5]), atol=1e-3)
@pytest.mark.parametrize("mask_fixture, margin, overflow", [
("rect_mask", "16", False),
("circle_mask", "32", False),
("rect_mask", "64", True),
("circle_mask", "0", False),
def test_crop_and_fit_edge_clamp(sample_image):
node = CropAndFitInpaintToDiffusionSize()
edge_mask = torch.zeros(1, 256, 256)
edge_mask[:, :20, :50] = 1.0
_, _, context = node.crop_and_fit(sample_image, edge_mask, "SD1.5", "30")
target_aspect_ratio = 1.0 # For SD1.5, the only valid resolution is 512x512
actual_aspect_ratio = context.width / context.height
assert abs(actual_aspect_ratio - target_aspect_ratio) < 1e-4
@pytest.mark.parametrize("mask_fixture, margin", [
("rect_mask", "16"),
("circle_mask", "32"),
("circle_mask", "0"),
])
def test_end_to_end_composition(request, sample_image, mask_fixture, margin, overflow):
"""Performs a full round-trip test of both nodes."""
def test_end_to_end_composition(request, sample_image, mask_fixture, margin):
mask = request.getfixturevalue(mask_fixture)
crop_node = CropAndFitInpaintToDiffusionSize()
composite_node = CompositeCroppedAndFittedInpaintResult()
# The resized mask from the first node is not needed for compositing.
cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin, overflow)
cropped_img, _, context = crop_node.crop_and_fit(sample_image, mask, "SD1.5", margin)
h, w = cropped_img.shape[1:3]
blue_color = torch.tensor([0.1, 0.2, 0.9]).view(1, 1, 1, 3)
inpainted_sim = blue_color.expand(1, h, w, 3)
# FIX: Pass the original, high-resolution mask as `source_mask`.
final_image, = composite_node.composite_result(
source_image=sample_image,
source_mask=mask,
@@ -66,5 +72,41 @@ def test_end_to_end_composition(request, sample_image, mask_fixture, margin, ove
assert final_image.shape == sample_image.shape
bool_mask = mask.squeeze(0).bool()
assert torch.allclose(final_image[0][bool_mask], blue_color.squeeze(), atol=1e-2)
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask], atol=1e-2)
assert torch.allclose(final_image[0][~bool_mask], sample_image[0][~bool_mask])
def test_wide_ideogram_composite(image_1024):
"""Tests the wide margin scenario. The node logic correctly chooses 1536x512."""
source_image = image_1024
mask = torch.zeros(1, 1024, 1024)
mask[:, 900:932, 950:982] = 1.0
crop_node = CropAndFitInpaintToDiffusionSize()
composite_node = CompositeCroppedAndFittedInpaintResult()
margin = "64 64 64 400"
cropped_img, _, context = crop_node.crop_and_fit(source_image, mask, "Ideogram", margin)
assert cropped_img.shape[1:3] == (512, 1536)
green_color = torch.tensor([0.1, 0.9, 0.2]).view(1, 1, 1, 3)
inpainted_sim = green_color.expand(1, 512, 1536, 3)
final_image, = composite_node.composite_result(
source_image=source_image,
source_mask=mask,
inpainted_image=inpainted_sim,
composite_context=context
)
assert final_image.shape == source_image.shape
bool_mask = mask.squeeze(0).bool()
final_pixels = final_image[0][bool_mask]
assert torch.all(final_pixels[:, 1] > final_pixels[:, 0])
assert torch.all(final_pixels[:, 1] > final_pixels[:, 2])
assert torch.allclose(final_image[0, 916, 940, :], source_image[0, 916, 940, :])
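To make the wide-margin expectations concrete, a small sketch that recomputes the crop geometry by hand, assuming 1536x512 is the Ideogram resolution the node picks (as the docstring above states):

# Mask bounding box: rows 900..931, cols 950..981 (32x32 pixels).
x_min, x_max, y_min, y_max = 950, 981, 900, 931
top_m, right_m, bottom_m, left_m = 64, 64, 64, 400   # parse_margin("64 64 64 400")

# Expand by the margins, then clamp to the 1024x1024 image.
x_start, x_end = max(0, x_min - left_m), min(1024, x_max + 1 + right_m)   # 550, 1024
y_start, y_end = max(0, y_min - top_m), min(1024, y_max + 1 + bottom_m)   # 836, 996
w, h = x_end - x_start, y_end - y_start                                   # 474, 160 -> aspect ~2.96

# Fit to the 1536x512 target aspect ratio (3.0) by reducing the height.
h_fit = int(w / (1536 / 512))                                             # 158
# The resulting 474x158 window is resized to 1536x512, matching the shape assertion above.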