import torch
import torch.nn.functional as F

from comfy.component_model.tensor_types import MaskBatch
from comfy_extras.constants.resolutions import (
    RESOLUTION_NAMES, SDXL_SD3_FLUX_RESOLUTIONS, SD_RESOLUTIONS, LTVX_RESOLUTIONS,
    IDEOGRAM_RESOLUTIONS, COSMOS_RESOLUTIONS, HUNYUAN_VIDEO_RESOLUTIONS,
    WAN_VIDEO_14B_RESOLUTIONS, WAN_VIDEO_1_3B_RESOLUTIONS,
    WAN_VIDEO_14B_EXTENDED_RESOLUTIONS
)

def composite(destination, source, x, y, mask=None, multiplier=1, resize_source=False):
    """Alpha-blend `source` (BCHW) onto `destination` (BCHW) at pixel offset (x, y).

    Out-of-bounds or negative offsets are handled by clipping both tensors to the
    overlapping region. `mask` may be HW, BHW, or BCHW; missing dims are added and its
    batch is repeated to match `source`. `multiplier` is accepted but unused here.
    """
    source = source.to(destination.device)
    if resize_source:
        source = F.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
    if source.shape[0] != destination.shape[0]:
        source = source.repeat(destination.shape[0] // source.shape[0], 1, 1, 1)

    x, y = int(x), int(y)
    left, top = x, y
    right, bottom = left + source.shape[3], top + source.shape[2]

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        if mask.dim() == 2: mask = mask.unsqueeze(0)
        if mask.dim() == 3: mask = mask.unsqueeze(1)
        if mask.shape[0] != source.shape[0]:
            mask = mask.repeat(source.shape[0] // mask.shape[0], 1, 1, 1)

    # Clip the paste rectangle to the destination bounds.
    dest_left, dest_top = max(0, left), max(0, top)
    dest_right, dest_bottom = min(destination.shape[3], right), min(destination.shape[2], bottom)

    if dest_right <= dest_left or dest_bottom <= dest_top: return destination

    # Matching rectangle inside the source/mask tensors.
    src_left, src_top = dest_left - left, dest_top - top
    src_right, src_bottom = dest_right - left, dest_bottom - top

    destination_portion = destination[:, :, dest_top:dest_bottom, dest_left:dest_right]
    source_portion = source[:, :, src_top:src_bottom, src_left:src_right]
    mask_portion = mask[:, :, src_top:src_bottom, src_left:src_right]

    blended_portion = (source_portion * mask_portion) + (destination_portion * (1.0 - mask_portion))
    destination[:, :, dest_top:dest_bottom, dest_left:dest_right] = blended_portion
    return destination

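# Illustrative use of composite() (a sketch, not part of the node graph): paste a 64x64
# BCHW patch onto a 512x512 canvas at (100, 100); the HW mask gets batch/channel dims added.
#
#   canvas = torch.zeros(1, 3, 512, 512)
#   patch = torch.ones(1, 3, 64, 64)
#   patch_mask = torch.ones(64, 64)
#   out = composite(canvas, patch, x=100, y=100, mask=patch_mask)
#   # out[:, :, 100:164, 100:164] now holds the patch; the rest of the canvas is untouched.
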
def parse_margin(margin_str: str) -> tuple[int, int, int, int]:
    """Parse a CSS-shorthand-style margin string into (top, right, bottom, left) pixels."""
    parts = [int(p) for p in margin_str.strip().split()]
    if len(parts) == 1: return parts[0], parts[0], parts[0], parts[0]
    if len(parts) == 2: return parts[0], parts[1], parts[0], parts[1]
    if len(parts) == 3: return parts[0], parts[1], parts[2], parts[1]
    if len(parts) == 4: return parts[0], parts[1], parts[2], parts[3]
    raise ValueError("Invalid margin format.")

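# Example values for parse_margin (illustrative; follows the CSS shorthand order
# top/right/bottom/left used by the parser above):
#   "64"       -> (64, 64, 64, 64)
#   "16 32"    -> (16, 32, 16, 32)
#   "8 16 24"  -> (8, 16, 24, 16)
#   "1 2 3 4"  -> (1, 2, 3, 4)
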
class CropAndFitInpaintToDiffusionSize:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "image": ("IMAGE",),
            "mask": ("MASK",),
            "resolutions": (RESOLUTION_NAMES, {"default": RESOLUTION_NAMES[0]}),
            "margin": ("STRING", {"default": "64"}),
            "overflow": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES, RETURN_NAMES, FUNCTION, CATEGORY = ("IMAGE", "MASK", "COMBO[INT]"), ("image", "mask", "composite_context"), "crop_and_fit", "inpaint"

    def crop_and_fit(self, image: torch.Tensor, mask: MaskBatch, resolutions: str, margin: str, overflow: bool, aspect_ratio_tolerance=0.05):
        if mask.max() <= 0: raise ValueError("Mask is empty.")
        mask_coords = torch.nonzero(mask[0])
        if mask_coords.numel() == 0: raise ValueError("Mask is empty.")

        # Bounding box of the masked region, expanded by the requested margins.
        y_min, x_min = mask_coords.min(dim=0).values
        y_max, x_max = mask_coords.max(dim=0).values
        top_m, right_m, bottom_m, left_m = parse_margin(margin)
        x_start_init, y_start_init = x_min.item() - left_m, y_min.item() - top_m
        x_end_init, y_end_init = x_max.item() + 1 + right_m, y_max.item() + 1 + bottom_m

        img_h, img_w = image.shape[1:3]
        pad_image, pad_mask = image, mask
        x_start_crop, y_start_crop = x_start_init, y_start_init
        x_end_crop, y_end_crop = x_end_init, y_end_init
        pad_l, pad_t = -min(0, x_start_init), -min(0, y_start_init)
        pad_r, pad_b = max(0, x_end_init - img_w), max(0, y_end_init - img_h)
        if any([pad_l, pad_t, pad_r, pad_b]) and overflow:
            # Overflow: pad image (neutral gray) and mask so the margin can extend past the borders.
            padding = (pad_l, pad_r, pad_t, pad_b)
            pad_image = F.pad(image.permute(0, 3, 1, 2), padding, "constant", 0.5).permute(0, 2, 3, 1)
            pad_mask = F.pad(mask.unsqueeze(1), padding, "constant", 0).squeeze(1)
            x_start_crop += pad_l
            y_start_crop += pad_t
            x_end_crop += pad_l
            y_end_crop += pad_t
        else:
            # No overflow: clamp the crop window to the image bounds.
            x_start_crop, y_start_crop = max(0, x_start_init), max(0, y_start_init)
            x_end_crop, y_end_crop = min(img_w, x_end_init), min(img_h, y_end_init)

        composite_x, composite_y = (x_start_init if overflow else x_start_crop), (y_start_init if overflow else y_start_crop)
        cropped_image = pad_image[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop, :]
        cropped_mask = pad_mask[:, y_start_crop:y_end_crop, x_start_crop:x_end_crop]
        context = {"x": composite_x, "y": composite_y, "width": cropped_image.shape[2], "height": cropped_image.shape[1]}

        # Stack image and mask into one RGBA tensor so both are resized and cropped identically.
        rgba_bchw = torch.cat((cropped_image.permute(0, 3, 1, 2), cropped_mask.unsqueeze(1)), dim=1)
        res_map = {
            "SDXL/SD3/Flux": SDXL_SD3_FLUX_RESOLUTIONS,
            "SD1.5": SD_RESOLUTIONS,
            "LTXV": LTVX_RESOLUTIONS,
            "Ideogram": IDEOGRAM_RESOLUTIONS,
            "Cosmos": COSMOS_RESOLUTIONS,
            "HunyuanVideo": HUNYUAN_VIDEO_RESOLUTIONS,
            "WAN 14b": WAN_VIDEO_14B_RESOLUTIONS,
            "WAN 1.3b": WAN_VIDEO_1_3B_RESOLUTIONS,
            "WAN 14b with extras": WAN_VIDEO_14B_EXTENDED_RESOLUTIONS,
        }
        supported_resolutions = res_map.get(resolutions, SD_RESOLUTIONS)

        # Pick the largest supported resolution whose aspect ratio is within the tolerance
        # of the best match, then scale-to-cover and center-crop onto it.
        h, w = cropped_image.shape[1:3]
        current_aspect_ratio = w / h
        diffs = [(abs(res[0] / res[1] - current_aspect_ratio), res) for res in supported_resolutions]
        min_diff = min(diffs, key=lambda x: x[0])[0]
        close_res = [res for diff, res in diffs if diff <= min_diff + aspect_ratio_tolerance]
        target_res = max(close_res, key=lambda r: r[0] * r[1])
        scale = max(target_res[0] / w, target_res[1] / h)
        new_w, new_h = int(w * scale), int(h * scale)
        upscaled_rgba = F.interpolate(rgba_bchw, size=(new_h, new_w), mode="bilinear", align_corners=False)
        y1, x1 = (new_h - target_res[1]) // 2, (new_w - target_res[0]) // 2
        final_rgba_bchw = upscaled_rgba[:, :, y1:y1 + target_res[1], x1:x1 + target_res[0]]
        final_rgba_bhwc = final_rgba_bchw.permute(0, 2, 3, 1)
        resized_image = final_rgba_bhwc[..., :3]
        resized_mask = (final_rgba_bhwc[..., 3] > 0.5).float()
        return (resized_image, resized_mask, (context["x"], context["y"], context["width"], context["height"]))

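# Worked example of the fit step above (hypothetical numbers; the actual buckets come from
# comfy_extras.constants.resolutions): a 600x400 (w x h) padded crop has aspect 1.5. If the
# closest bucket by aspect ratio is 1216x832, then scale = max(1216/600, 832/400) = 2.08,
# the crop is upscaled to 1248x832, and a centered 1216x832 window is taken, so image and
# mask land exactly on the target resolution while the context tuple still records the
# original crop position and size for compositing later.
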
class CompositeCroppedAndFittedInpaintResult:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "source_image": ("IMAGE",),
            "source_mask": ("MASK",),
            "inpainted_image": ("IMAGE",),
            "composite_context": ("COMBO[INT]",),
        }}

    RETURN_TYPES, FUNCTION, CATEGORY = ("IMAGE",), "composite_result", "inpaint"

    def composite_result(self, source_image: torch.Tensor, source_mask: MaskBatch, inpainted_image: torch.Tensor, composite_context: tuple[int, ...]):
        x, y, width, height = composite_context
        target_size = (height, width)

        # Scale the inpainted result back to the size of the original crop region.
        resized_inpainted_image = F.interpolate(inpainted_image.permute(0, 3, 1, 2), size=target_size, mode="bilinear", align_corners=False)

        # Crop the relevant section of the original, high-resolution source_mask directly.
        # Clamping to the image bounds also covers negative coordinates from the overflow case.
        crop_x_start = max(0, x)
        crop_y_start = max(0, y)
        crop_x_end = min(source_image.shape[2], x + width)
        crop_y_end = min(source_image.shape[1], y + height)

        # The mask used for compositing is a direct, high-resolution crop of the source mask.
        final_compositing_mask = source_mask[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end]

        destination_image = source_image.clone().permute(0, 3, 1, 2)

        # composite() places the patch at (x, y) and clips it to the destination bounds.
        final_image_permuted = composite(destination=destination_image, source=resized_inpainted_image, x=x, y=y, mask=final_compositing_mask)

        return (final_image_permuted.permute(0, 2, 3, 1),)

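# Round-trip sketch (illustrative; in practice these nodes are wired up in a ComfyUI graph):
#
#   cropper = CropAndFitInpaintToDiffusionSize()
#   fitted_image, fitted_mask, ctx = cropper.crop_and_fit(image, mask, "SDXL/SD3/Flux", "64", True)
#   # ... run an inpainting model on fitted_image / fitted_mask -> inpainted ...
#   compositor = CompositeCroppedAndFittedInpaintResult()
#   (result,) = compositor.composite_result(image, mask, inpainted, ctx)
#
# ctx carries (x, y, width, height) of the original crop; x/y may be negative when overflow
# padding was applied, which composite() handles by clipping to the destination bounds.
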
NODE_CLASS_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": CropAndFitInpaintToDiffusionSize,
    "CompositeCroppedAndFittedInpaintResult": CompositeCroppedAndFittedInpaintResult,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "CropAndFitInpaintToDiffusionSize": "Crop & Fit Inpaint Region",
    "CompositeCroppedAndFittedInpaintResult": "Composite Inpaint Result",
}