diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index cbc9bf158..33af1595f 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -266,11 +266,7 @@ class Rotate: "max": 360, "step": 0.1 }), - "resample": ([ - "Nearest Neighbor", - "Bilinear", - "Bicubic", - ],), + "resample": (["nearest neighbor", "bilinear", "bicubic"],), "expand": (["disabled", "enabled"],), "center_x": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), "center_y": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), @@ -290,9 +286,9 @@ class Rotate: batch_size, height, width, _ = image.shape resamplers = { - "Nearest Neighbor": Image.Resampling.NEAREST, - "Bilinear": Image.Resampling.BILINEAR, - "Bicubic": Image.Resampling.BICUBIC, + "nearest neighbor": Image.Resampling.NEAREST, + "bilinear": Image.Resampling.BILINEAR, + "bicubic": Image.Resampling.BICUBIC, } tensor_image = image[0] @@ -430,6 +426,7 @@ class Composite: "image_b": ("IMAGE",), "x": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), "y": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), + "resample": (["nearest neighbor", "box", "bilinear", "bicubic", "hamming", "lanczos"],), }, "optional": { "mask": ("MASK",), @@ -441,7 +438,16 @@ class Composite: CATEGORY = "image/postprocessing" - def composite(self, image_a: torch.Tensor, image_b: torch.Tensor, x: int, y: int, mask: torch.Tensor = None): + def composite(self, image_a: torch.Tensor, image_b: torch.Tensor, x: int, y: int, resample: str, mask: torch.Tensor = None): + resamplers = { + "nearest neighbor": Image.Resampling.NEAREST, + "bilinear": Image.Resampling.BILINEAR, + "bicubic": Image.Resampling.BICUBIC, + "box": Image.Resampling.BOX, + "hamming": Image.Resampling.HAMMING, + "lanczos": Image.Resampling.LANCZOS, + } + batch_size, height, width, _ = image_a.shape result = torch.zeros_like(image_a) @@ -455,6 +461,8 @@ class Composite: pil_image_a = Image.fromarray(img_a, mode='RGB') pil_image_b = Image.fromarray(img_b, mode='RGB') pil_image_mask = Image.fromarray(img_mask, mode='L') + if pil_image_mask.size != pil_image_b.size: + pil_image_mask = pil_image_mask.resize(pil_image_b.size, resamplers[resample]) pil_image_a.paste(pil_image_b, (x, y), pil_image_mask)