Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-12 07:10:52 +08:00
refactor and cleanup of inpainting nodes

commit 564e14ea97 (parent d4c9d5c748)
@@ -90,4 +90,16 @@ WAN_VIDEO_14B_EXTENDED_RESOLUTIONS = [
     (544, 704)
 ]
 
-RESOLUTION_NAMES = ["SDXL/SD3/Flux", "SD1.5", "LTXV", "Ideogram", "Cosmos", "HunyuanVideo", "WAN 14b", "WAN 1.3b", "WAN 14b with extras"]
+RESOLUTION_MAP = {
+    "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS,
+    "SD1.5": SD_RESOLUTIONS,
+    "LTXV": LTVX_RESOLUTIONS,
+    "Ideogram": IDEOGRAM_RESOLUTIONS,
+    "Cosmos": COSMOS_RESOLUTIONS,
+    "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS,
+    "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS,
+    "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS,
+    "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS,
+}
+
+RESOLUTION_NAMES = list(RESOLUTION_MAP.keys())
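The hunk above replaces the hand-maintained name list with a dict and derives the dropdown names from its keys. A minimal sketch of that single-source-of-truth pattern; the resolution tuples here are illustrative, not the project's real tables:

# Illustrative resolution tables, not the project's real values.
SD_RESOLUTIONS = [(512, 512), (512, 768), (768, 512)]
LTVX_RESOLUTIONS = [(768, 512), (1216, 704)]

RESOLUTION_MAP = {
    "SD1.5": SD_RESOLUTIONS,
    "LTXV": LTVX_RESOLUTIONS,
}

# Deriving the names from the map means adding a resolution family is a
# one-line change; the dropdown can never drift out of sync with the lookup.
RESOLUTION_NAMES = list(RESOLUTION_MAP.keys())
assert RESOLUTION_NAMES == ["SD1.5", "LTXV"]  # dicts preserve insertion order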
@@ -16,7 +16,7 @@ from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
 from comfy_extras.constants.resolutions import SDXL_SD3_FLUX_RESOLUTIONS, LTVX_RESOLUTIONS, SD_RESOLUTIONS, \
     IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS, WAN_VIDEO_14B_RESOLUTIONS, \
-    WAN_VIDEO_1_3B_RESOLUTIONS, WAN_VIDEO_14B_EXTENDED_RESOLUTIONS, RESOLUTION_NAMES
+    WAN_VIDEO_1_3B_RESOLUTIONS, WAN_VIDEO_14B_EXTENDED_RESOLUTIONS, RESOLUTION_NAMES, RESOLUTION_MAP
 
 
 def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
@@ -287,24 +287,7 @@ class ImageResize:
     CATEGORY = "image/transform"
 
     def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
-        if resolutions == "SDXL/SD3/Flux":
-            supported_resolutions = SDXL_SD3_FLUX_RESOLUTIONS
-        elif resolutions == "LTXV":
-            supported_resolutions = LTVX_RESOLUTIONS
-        elif resolutions == "Ideogram":
-            supported_resolutions = IDEOGRAM_RESOLUTIONS
-        elif resolutions == "Cosmos":
-            supported_resolutions = COSMOS_RESOLUTIONS
-        elif resolutions == "HunyuanVideo":
-            supported_resolutions = HUNYUAN_VIDEO_RESOLUTIONS
-        elif resolutions == "WAN 14b":
-            supported_resolutions = WAN_VIDEO_14B_RESOLUTIONS
-        elif resolutions == "WAN 1.3b":
-            supported_resolutions = WAN_VIDEO_1_3B_RESOLUTIONS
-        elif resolutions == "WAN 14b with extras":
-            supported_resolutions = WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
-        else:
-            supported_resolutions = SD_RESOLUTIONS
+        supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
         return self.resize_image_with_supported_resolutions(image, resize_mode, supported_resolutions, interpolation, aspect_ratio_tolerance=aspect_ratio_tolerance)
 
     def resize_image_with_supported_resolutions(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], supported_resolutions: list[tuple[int, int]], interpolation: str, aspect_ratio_tolerance=0.05) -> tuple[RGBImageBatch]:
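This hunk collapses the 18-line if/elif chain into a single dict lookup; dict.get with a default reproduces the removed else arm's fallback to SD1.5 exactly. A small sketch with illustrative tables showing the two lookups agree:

# Illustrative tables; the real ones live in comfy_extras.constants.resolutions.
SD_RESOLUTIONS = [(512, 512)]
SDXL_SD3_FLUX_RESOLUTIONS = [(1024, 1024)]
RESOLUTION_MAP = {"SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS}

def old_lookup(name):
    # The shape of the removed chain: explicit branches plus an SD1.5 fallback.
    if name == "SDXL/SD3/Flux":
        return SDXL_SD3_FLUX_RESOLUTIONS
    else:
        return SD_RESOLUTIONS

def new_lookup(name):
    # The replacement: one lookup, same fallback.
    return RESOLUTION_MAP.get(name, SD_RESOLUTIONS)

for name in ("SDXL/SD3/Flux", "unknown"):
    assert old_lookup(name) == new_lookup(name)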
@@ -1,15 +1,12 @@
+from typing import NamedTuple, Optional
+
 import torch
 import torch.nn.functional as F
-from typing import NamedTuple, Optional
 
 from comfy.component_model.tensor_types import MaskBatch, ImageBatch
 from comfy.nodes.package_typing import CustomNode
-from comfy_extras.constants.resolutions import (
-    RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
-    IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
-    WAN_VIDEO_14B_RESOLUTIONS, WAN_VIDEO_1_3B_RESOLUTIONS,
-    WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
-)
+from comfy_extras.constants.resolutions import RESOLUTION_MAP, SD_RESOLUTIONS, RESOLUTION_NAMES
+
 
 class CompositeContext(NamedTuple):
     x: int
@@ -17,8 +14,8 @@ class CompositeContext(NamedTuple):
     width: int
     height: int
 
-def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None):
-    """A robust function to composite a source tensor onto a destination tensor."""
+
+def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask: Optional[MaskBatch] = None) -> ImageBatch:
     source = source.to(destination.device)
     if source.shape[0] != destination.shape[0]:
         source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)
@@ -31,8 +28,10 @@ def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask:
         mask = torch.ones_like(source)
     else:
         mask = mask.to(destination.device, copy=True)
-        if mask.dim() == 2: mask = mask.unsqueeze(0)
-        if mask.dim() == 3: mask = mask.unsqueeze(1)
+        if mask.dim() == 2:
+            mask = mask.unsqueeze(0)
+        if mask.dim() == 3:
+            mask = mask.unsqueeze(1)
         if mask.shape[0] != source.shape[0]:
             mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
 
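The expanded conditionals above normalize the mask to NCHW so it broadcasts against the image tensors: a 2D mask passes through both ifs in sequence. A runnable sketch of the shape progression:

import torch

# Sketch of the mask normalization above: a (H, W) or (B, H, W) mask is
# promoted to (B, 1, H, W) so it broadcasts against NCHW image tensors.
mask = torch.rand(64, 64)          # (H, W)
if mask.dim() == 2:
    mask = mask.unsqueeze(0)       # -> (1, H, W)
if mask.dim() == 3:
    mask = mask.unsqueeze(1)       # -> (1, 1, H, W); fires right after the 2D case
source = torch.rand(4, 3, 64, 64)  # batch of 4 NCHW images
if mask.shape[0] != source.shape[0]:
    # integer division assumes the source batch is a multiple of the mask batch
    mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)
assert mask.shape == (4, 1, 64, 64)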
@@ -46,40 +45,50 @@ def composite(destination: ImageBatch, source: ImageBatch, x: int, y: int, mask:
 
     destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
     source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
 
     # The mask must be cropped to the region of interest on the destination.
     mask_portion = mask[:, :, dest_top:dest_bottom, dest_left:dest_right]
 
     blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
     destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
     return destination
 
 
 def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
     parts = [int(p) for p in margin_str.strip().split()]
-    if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0]
-    if len(parts) == 2: return parts[0], parts[1], parts[0], parts[1]
-    if len(parts) == 3: return parts[0], parts[1], parts[2], parts[1]
-    if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3]
-    raise ValueError("Invalid margin format.")
+    match len(parts):
+        case 1:
+            return parts[0], parts[0], parts[0], parts[0]
+        case 2:
+            return parts[0], parts[1], parts[0], parts[1]
+        case 3:
+            return parts[0], parts[1], parts[2], parts[1]
+        case 4:
+            return parts[0], parts[1], parts[2], parts[3]
+        case _:
+            raise ValueError("Invalid margin format.")
 
 
 class CropAndFitInpaintToDiffusionSize(CustomNode):
     @classmethod
     def INPUT_TYPES(cls):
-        return {"required": {
-            "image": ("IMAGE",), "mask": ("MASK",),
-            "resolutions": (RESOLUTION_NAMES, {"default": "SD1.5"}),
-            "margin": ("STRING", {"default": "64"}),
-        }}
+        return {
+            "required": {
+                "image": ("IMAGE",), "mask": ("MASK",),
+                "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
+                "margin": ("STRING", {"default": "64"}),
+            }
+        }
 
     RETURN_TYPES = ("IMAGE", "MASK", "COMPOSITE_CONTEXT")
     RETURN_NAMES = ("image", "mask", "composite_context")
     FUNCTION = "crop_and_fit"
     CATEGORY = "inpaint"
 
-    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, aspect_ratio_tolerance=0.05):
-        if mask.max() <= 0: raise ValueError("Mask is empty.")
+    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str):
+        if mask.max() <= 0:
+            raise ValueError("Mask is empty.")
         mask_coords = torch.nonzero(mask)
-        if mask_coords.numel() == 0: raise ValueError("Mask is empty.")
+        if mask_coords.numel() == 0:
+            raise ValueError("Mask is empty.")
 
         y_coords, x_coords = mask_coords[:, 1], mask_coords[:, 2]
         y_min, x_min = y_coords.min().item(), x_coords.min().item()
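parse_margin appears to follow CSS shorthand order (top, right, bottom, left); the rewritten match form requires Python 3.10+ but keeps the same semantics. A self-contained copy with sanity checks, expected tuples inferred from the match arms:

def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
    # Copied from the hunk above; order inferred as (top, right, bottom, left).
    parts = [int(p) for p in margin_str.strip().split()]
    match len(parts):
        case 1:
            return parts[0], parts[0], parts[0], parts[0]
        case 2:
            return parts[0], parts[1], parts[0], parts[1]
        case 3:
            return parts[0], parts[1], parts[2], parts[1]
        case 4:
            return parts[0], parts[1], parts[2], parts[3]
        case _:
            raise ValueError("Invalid margin format.")

assert parse_margin("64") == (64, 64, 64, 64)
assert parse_margin("16 32") == (16, 32, 16, 32)   # vertical, horizontal
assert parse_margin("8 16 24") == (8, 16, 24, 16)  # top, horizontal, bottom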
@@ -94,10 +103,10 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
         clamped_x_end, clamped_y_end = min(img_w, x_end_expanded), min(img_h, y_end_expanded)
 
         initial_w, initial_h = clamped_x_end - clamped_x_start, clamped_y_end - clamped_y_start
-        if initial_w <= 0 or initial_h <= 0: raise ValueError("Cropped area has zero dimension.")
+        if initial_w <= 0 or initial_h <= 0:
+            raise ValueError("Cropped area has zero dimension.")
 
-        res_map = { "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS, "SD1.5": SD_RESOLUTIONS, "LTXV": LTVX_RESOLUTIONS, "Ideogram": IDEOGRAM_RESOLUTIONS, "Cosmos": COSMOS_RESOLUTIONS, "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS, "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS, "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS, "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS }
-        supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)
+        supported_resolutions = RESOLUTION_MAP.get(resolutions, SD_RESOLUTIONS)
         diffs = [(abs(res[0] / res[1] - (initial_w / initial_h)), res) for res in supported_resolutions]
         target_res = min(diffs, key=lambda x: x[0])[1]
         target_ar = target_res[0] / target_res[1]
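The aspect-ratio matching kept by this hunk picks the supported resolution whose w/h ratio minimizes the absolute difference from the crop's ratio. A worked sketch with illustrative resolutions:

# Sketch of the nearest-aspect-ratio selection above; tables are illustrative.
supported_resolutions = [(512, 512), (512, 768), (768, 512)]
initial_w, initial_h = 640, 480  # a 4:3 crop

diffs = [(abs(w / h - initial_w / initial_h), (w, h)) for w, h in supported_resolutions]
target_res = min(diffs, key=lambda x: x[0])[1]

# 640/480 ≈ 1.333; candidate ratios are 1.0, 0.667, 1.5 -> (768, 512) wins.
assert target_res == (768, 512)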
@@ -118,20 +127,30 @@ class CropAndFitInpaintToDiffusionSize(CustomNode):
         cropped_image = image[:, final_y:final_y + final_h, final_x:final_x + final_w]
         cropped_mask = mask[:, final_y:final_y + final_h, final_x:final_x + final_w]
 
-        resized_image = F.interpolate(cropped_image.permute(0,3,1,2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0,2,3,1)
+        resized_image = F.interpolate(cropped_image.permute(0, 3, 1, 2), size=(target_res[1], target_res[0]), mode="bilinear", align_corners=False).permute(0, 2, 3, 1)
         resized_mask = F.interpolate(cropped_mask.unsqueeze(1), size=(target_res[1], target_res[0]), mode="nearest").squeeze(1)
 
         composite_context = CompositeContext(x=final_x, y=final_y, width=final_w, height=final_h)
         return (resized_image, resized_mask, composite_context)
 
 
 class CompositeCroppedAndFittedInpaintResult(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"source_image": ("IMAGE",), "source_mask": ("MASK",), "inpainted_image": ("IMAGE",), "composite_context": ("COMPOSITE_CONTEXT",),}}
+        return {
+            "required": {
+                "source_image": ("IMAGE",),
+                "source_mask": ("MASK",),
+                "inpainted_image": ("IMAGE",),
+                "composite_context": ("COMPOSITE_CONTEXT",),
+            }
+        }
 
-    RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint"
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "composite_result"
+    CATEGORY = "inpaint"
 
-    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext):
+    def composite_result(self, source_image: ImageBatch, source_mask: MaskBatch, inpainted_image: ImageBatch, composite_context: CompositeContext) -> tuple[ImageBatch]:
         context_x, context_y, context_w, context_h = composite_context
 
         resized_inpainted = F.interpolate(
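ComfyUI IMAGE tensors are NHWC while F.interpolate expects NCHW, hence the permute round-trip this hunk reformats. Note also that target_res is (width, height) but interpolate's size= takes (height, width), and the mask path uses mode="nearest", presumably to keep mask edges hard. A minimal sketch of the image path:

import torch
import torch.nn.functional as F

# Sketch of the resize step above with an illustrative target resolution.
image = torch.rand(1, 480, 640, 3)         # NHWC, as ComfyUI passes images
target_res = (768, 512)                    # (width, height)

resized = F.interpolate(
    image.permute(0, 3, 1, 2),             # NHWC -> NCHW for interpolate
    size=(target_res[1], target_res[0]),   # size= is (H, W)
    mode="bilinear",
    align_corners=False,
).permute(0, 2, 3, 1)                      # back to NHWC

assert resized.shape == (1, 512, 768, 3)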
@@ -148,7 +167,14 @@ class CompositeCroppedAndFittedInpaintResult(CustomNode):
             mask=source_mask
         )
 
-        return (final_image.permute(0, 2, 3, 1),)
+        return final_image.permute(0, 2, 3, 1),
 
-NODE_CLASS_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize, "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult}
-NODE_DISPLAY_NAME_MAPPINGS = {"CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region", "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result"}
+
+NODE_CLASS_MAPPINGS = {
+    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
+    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
+}
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
+    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
+}
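The reformatted registration dicts are how ComfyUI discovers these nodes: the loader imports the module and reads NODE_CLASS_MAPPINGS and NODE_DISPLAY_NAME_MAPPINGS at module level. A minimal sketch of the convention; the class and names here are illustrative, not part of this commit:

class ExampleNode:
    # Hypothetical node following the same structure as the diff above.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "example"

    def run(self, image):
        return (image,)

# The loader reads these module-level dicts to register the node and
# pick the label shown in the UI.
NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}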