From 564e14ea973b58a87e97eff40655ec3617946d19 Mon Sep 17 00:00:00 2001
From: Benjamin Berman
Date: Mon, 9 Jun 2025 13:55:32 -0700
Subject: [PATCH] refactor and cleanup of inpainting nodes

Centralize the resolution lookup in a new RESOLUTION_MAP dict and
derive RESOLUTION_NAMES from its keys, so the name list and the
name-to-resolutions mapping can no longer drift apart. Replace the
if/elif chain in ImageResize.resize_image and the local res_map in
CropAndFitInpaintToDiffusionSize.crop_and_fit with RESOLUTION_MAP.get()
lookups.

Also rewrite parse_margin as a match statement, split single-line
compound statements, add return type annotations, expand the one-line
dict literals in INPUT_TYPES and the node mappings, and remove the
aspect_ratio_tolerance parameter from crop_and_fit. Note that the
default for the "resolutions" input of CropAndFitInpaintToDiffusionSize
changes from "SD1.5" to the first RESOLUTION_MAP entry
("SDXL/SD3/Flux").
---
 comfy_extras/constants/resolutions.py  | 14 +++-
 comfy_extras/nodes/nodes_images.py     | 21 +----
 comfy_extras/nodes/nodes_inpainting.py | 98 ++++++++++++++++---------
 3 files changed, 78 insertions(+), 55 deletions(-)

diff --git a/comfy_extras/constants/resolutions.py b/comfy_extras/constants/resolutions.py
index 542b9e338..1d2bd3b40 100644
--- a/comfy_extras/constants/resolutions.py
+++ b/comfy_extras/constants/resolutions.py
@@ -90,4 +90,16 @@ WAN_VIDEO_14B_EXTENDED_RESOLUTIONS = [
     (544, 704)
 ]
 
-RESOLUTION_NAMES = ["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram", "Cosmos", "HunyuanVideo", "WAN 14b", "WAN 1.3b", "WAN 14b with extras"]
\ No newline at end of file
+RESOLUTION_MAP = {
+    "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS,
+    "SD1.5": SD_RESOLUTIONS,
+    "LTXV": LTVX_RESOLUTIONS,
+    "Ideogram": IDEOGRAM_RESOLUTIONS,
+    "Cosmos": COSMOS_RESOLUTIONS,
+    "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS,
+    "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS,
+    "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS,
+    "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS,
+}
+
+RESOLUTION_NAMES = list(RESOLUTION_MAP.keys())
\ No newline at end of file
diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py
index 2129e4cb1..386cbe28b 100644
--- a/comfy_extras/nodes/nodes_images.py
+++ b/comfy_extras/nodes/nodes_images.py
@@ -16,7 +16,7 @@ from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
 from comfy_extras.constants.resolutions import SDXL_SD3_FLUX_RESOLUTIONS, LTVX_RESOLUTIONS, SD_RESOLUTIONS, \
     IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS, WAN_VIDEO_14B_RESOLUTIONS, \
-    WAN_VIDEO_1_3B_RESOLUTIONS, WAN_VIDEO_14B_EXTENDED_RESOLUTIONS, RESOLUTION_NAMES
+    WAN_VIDEO_1_3B_RESOLUTIONS, WAN_VIDEO_14B_EXTENDED_RESOLUTIONS, RESOLUTION_NAMES, RESOLUTION_MAP
 
 
 def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
@@ -287,24 +287,7 @@ class ImageResize:
     CATEGORY = "image/transform"
 
     def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
-        if resolutions == "SDXL/SD3/Flux":
-            supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS
-        elif resolutions == "LTXV":
-            supported_resolutions = LTVX_RESOLUTIONS
-        elif resolutions == "Ideogram":
-            supported_resolutions = IDEOGRAM_RESOLUTIONS
-        elif resolutions == "Cosmos":
-            supported_resolutions = COSMOS_RESOLUTIONS
-        elif resolutions == "HunyuanVideo":
-            supported_resolutions = HUNYUAN_VIDEO_RESOLUTIONS
-        elif resolutions == "WAN 14b":
-            supported_resolutions = WAN_VIDEO_14B_RESOLUTIONS
-        elif resolutions == "WAN 1.3b":
-            supported_resolutions = WAN_VIDEO_1_3B_RESOLUTIONS
-        elif resolutions == "WAN 14b with extras":
-            supported_resolutions = WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
-        else:
-            supported_resolutions = SD_RESOLUTIONS
+        supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
         return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation, aspect_ratio_tolerance=aspect_ratio_tolerance)
 
     def resize_image_with_supported_resolutions(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"],
                                                  supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
diff --git a/comfy_extras/nodes/nodes_inpainting.py b/comfy_extras/nodes/nodes_inpainting.py
index 6720f5963..30c2a63ab 100644
--- a/comfy_extras/nodes/nodes_inpainting.py
+++ b/comfy_extras/nodes/nodes_inpainting.py
@@ -1,15 +1,12 @@
+from typing import NamedTuple, Optional
+
 import torch
 import torch.nn.functional as F
-from typing import NamedTuple, Optional
 
 from comfy.component_model.tensor_types import MaskBatch, ImageBatch
 from comfy.nodes.package_typing import CustomNode
-from comfy_extras.constants.resolutions import (
-    RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
-    IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
-    WAN_VIDEO_14B_RESOLUTIONS, WAN_VIDEO_1_3B_RESOLUTIONS,
-    WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
-)
+from comfy_extras.constants.resolutions import RESOLUTION_MAP, SD_RESOLUTIONS, RESOLUTION_NAMES
+
 
 class CompositeContext(NamedTuple):
     x: int
@@ -17,8 +14,8 @@ class CompositeContext(NamedTuple):
     width: int
     height: int
 
-def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None):
-    """A robust function to composite a source tensor onto a destination tensor."""
+
+def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None) -> ImageBatch:
     source = source.to(destination.device)
     if source.shape[0] != destination.shape[0]:
         source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
@@ -31,8 +28,10 @@ def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask:
         mask = torch.ones_like(source)
     else:
         mask = mask.to(destination.device, copy=True)
-        if mask.dim() == 2: mask = mask.unsqueeze(0)
-        if mask.dim() == 3: mask = mask.unsqueeze(1)
+        if mask.dim() == 2:
+            mask = mask.unsqueeze(0)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(1)
         if mask.shape[0] != source.shape[0]:
             mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
 
@@ -46,40 +45,51 @@ def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask:
     destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
     source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
-
-    # The mask must be cropped to the region of interest on the destination.
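+    # The mask is destination-sized, so crop it with the destination coordinates.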
     mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right]
 
     blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
     destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
     return destination
 
+
 def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
     parts = [int(p) for p in margin_str.strip().split()]
-    if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0]
-    if len(parts) == 2: return parts[0], parts[1], parts[0], parts[1]
-    if len(parts) == 3: return parts[0], parts[1], parts[2], parts[1]
-    if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3]
-    raise ValueError("Invalid margin format.")
+    match len(parts):
+        case 1:
+            return parts[0], parts[0], parts[0], parts[0]
+        case 2:
+            return parts[0], parts[1], parts[0], parts[1]
+        case 3:
+            return parts[0], parts[1], parts[2], parts[1]
+        case 4:
+            return parts[0], parts[1], parts[2], parts[3]
+        case _:
+            raise ValueError("Invalid margin format.")
 
+
 class CropAndFitInpaintToDiffusionSize(CustomNode):
     @classmethod
     def INPUT_TYPES(cls):
-        return {"required": {
-            "image": ("IMAGE",), "mask": ("MASK",),
-            "resolutions": (RESOLUTION_NAMES, {"default": "SD1.5"}),
-            "margin": ("STRING", {"default": "64"}),
-        }}
+        return {
+            "required": {
+                "image": ("IMAGE",), "mask": ("MASK",),
+                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
+                "margin": ("STRING", {"default": "64"}),
+            }
+        }
 
     RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
     RETURN_NAMES = ("image", "mask", "composite_context")
     FUNCTION = "crop_and_fit"
     CATEGORY = "inpaint"
 
-    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, aspect_ratio_tolerance=0.05):
-        if mask.max() <= 0: raise ValueError("Mask is empty.")
+    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
+        if mask.max() <= 0:
+            raise ValueError("Mask is empty.")
 
         mask_coords = torch.nonzero(mask)
-        if mask_coords.numel() == 0: raise ValueError("Mask is empty.")
+        if mask_coords.numel() == 0:
+            raise ValueError("Mask is empty.")
 
         y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
         y_min, x_min = y_coords.min().item(), x_coords.min().item()
@@ -94,10 +104,10 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
         clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)
 
         initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
-        if initial_w <= 0 or initial_h <= 0: raise ValueError("Cropped area has zero dimension.")
+        if initial_w <= 0 or initial_h <= 0:
+            raise ValueError("Cropped area has zero dimension.")
 
-        res_map = { "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS }
-        supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)
+        supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
         diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
         target_res = min(diffs, key=lambda x: x[0])[1]
         target_ar = target_res[0] / target_res[1]
@@ -118,20 +128,31 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
         cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
         cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]
 
-        resized_image = F.interpolate(cropped_image.permute(0,3,1,2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0,2,3,1)
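+        # F.interpolate expects BCHW, so permute the BHWC image before resizing and back after.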
+        resized_image = F.interpolate(cropped_image.permute(0, 3, 1, 2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
         resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1)
 
         composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
         return (resized_image, resized_mask, composite_context)
 
+
 class CompositeCroppedAndFittedInpaintResult(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMPOSITE_CONTEXT",),}}
+        return {
+            "required": {
+                "source_image": ("IMAGE",),
+                "source_mask": ("MASK",),
+                "inpainted_image": ("IMAGE",),
+                "composite_context": ("COMPOSITE_CONTEXT",),
+            }
+        }
 
-    RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint"
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "composite_result"
+    CATEGORY = "inpaint"
 
-    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext):
+    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext) -> tuple[ImageBatch]:
         context_x, context_y, context_w, context_h = composite_context
 
         resized_inpainted = F.interpolate(
@@ -148,7 +169,14 @@ class CompositeCroppedAndFittedInpaintResult(CustomNode):
             mask=source_mask
         )
 
         return (final_image.permute(0, 2, 3, 1),)
 
-NODE_CLASS_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize, "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult}
-NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}
\ No newline at end of file
+
+NODE_CLASS_MAPPINGS = {
+    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
+    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
+    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
+}
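
Usage sketch (reviewer note): a minimal example of chaining the two
nodes directly from Python, outside the graph executor. Here
run_inpaint_model is a hypothetical placeholder for the sampling step;
per the conventions above, IMAGE tensors are BHWC in [0, 1] and MASK
tensors are BHW.

    # Crop the masked region (plus margin) and fit it to the nearest
    # supported diffusion resolution.
    crop_node = CropAndFitInpaintToDiffusionSize()
    fitted_image, fitted_mask, ctx = crop_node.crop_and_fit(
        image, mask, resolutions="SD1.5", margin="64"
    )

    # Inpaint only the cropped, resolution-fitted region.
    inpainted = run_inpaint_model(fitted_image, fitted_mask)  # hypothetical

    # Paste the inpainted patch back into the source image, blended by
    # the original mask via composite().
    composite_node = CompositeCroppedAndFittedInpaintResult()
    (final_image,) = composite_node.composite_result(
        image, mask, inpainted, ctx
    )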