diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index b1e0d4666..ed54ccc57 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -304,10 +304,23 @@ Optional spacing can be added between images. image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" ).movedim(1, -1) + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + + color_val = color_map[spacing_color] + # When not matching sizes, pad to align non-concat dimensions if not match_image_size: h1, w1 = image1.shape[1:3] h2, w2 = image2.shape[1:3] + pad_value = 0.0 + if not isinstance(color_val, tuple): + pad_value = color_val if direction in ["left", "right"]: # For horizontal concat, pad heights to match @@ -316,11 +329,11 @@ Optional spacing can be added between images. if h1 < target_h: pad_h = target_h - h1 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) if h2 < target_h: pad_h = target_h - h2 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) else: # up, down # For vertical concat, pad widths to match if w1 != w2: @@ -328,11 +341,11 @@ Optional spacing can be added between images. if w1 < target_w: pad_w = target_w - w1 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) if w2 < target_w: pad_w = target_w - w2 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) # Ensure same number of channels if image1.shape[-1] != image2.shape[-1]: @@ -366,15 +379,6 @@ Optional spacing can be added between images. if spacing_width > 0: spacing_width = spacing_width + (spacing_width % 2) # Ensure even - color_map = { - "white": 1.0, - "black": 0.0, - "red": (1.0, 0.0, 0.0), - "green": (0.0, 1.0, 0.0), - "blue": (0.0, 0.0, 1.0), - } - color_val = color_map[spacing_color] - if direction in ["left", "right"]: spacing_shape = ( image1.shape[0], @@ -410,6 +414,62 @@ Optional spacing can be added between images. concat_dim = 2 if direction in ["left", "right"] else 1 return (torch.cat(images, dim=concat_dim),) +class ResizeAndPadImage: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "target_width": ("INT", { + "default": 512, + "min": 1, + "max": MAX_RESOLUTION, + "step": 1 + }), + "target_height": ("INT", { + "default": 512, + "min": 1, + "max": MAX_RESOLUTION, + "step": 1 + }), + "padding_color": (["white", "black"],), + "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "resize_and_pad" + CATEGORY = "image/transform" + + def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + batch_size, orig_height, orig_width, channels = image.shape + + scale_w = target_width / orig_width + scale_h = target_height / orig_height + scale = min(scale_w, scale_h) + + new_width = int(orig_width * scale) + new_height = int(orig_height * scale) + + image_permuted = image.permute(0, 3, 1, 2) + + resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled") + + pad_value = 0.0 if padding_color == "black" else 1.0 + padded = torch.full( + (batch_size, channels, target_height, target_width), + pad_value, + dtype=image.dtype, + device=image.device + ) + + y_offset = (target_height - new_height) // 2 + x_offset = (target_width - new_width) // 2 + + padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized + + output = padded.permute(0, 2, 3, 1) + return (output,) class SaveSVGNode: """ @@ -532,5 +592,6 @@ NODE_CLASS_MAPPINGS = { "SaveAnimatedPNG": SaveAnimatedPNG, "SaveSVGNode": SaveSVGNode, "ImageStitch": ImageStitch, + "ResizeAndPadImage": ResizeAndPadImage, "GetImageSize": GetImageSize, }