From d36ad5d958e8de54665b70999417525d8841a2ea Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 8 Apr 2023 20:43:43 -0600 Subject: [PATCH] Add transpose and rotate nodes --- comfy_extras/nodes_post_processing.py | 122 +++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ba699e2b8..b8585b53f 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -1,7 +1,8 @@ import numpy as np import torch import torch.nn.functional as F -from PIL import Image +from PIL import Image, ImageColor +import re import comfy.utils @@ -202,9 +203,128 @@ class Sharpen: return (result,) +class Transpose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "method": ([ + "Flip horizontal", + "Flip vertical", + "Rotate 90°", + "Rotate 180°", + "Rotate 270°", + "Transpose", + "Transverse", + ],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "transpose" + + CATEGORY = "image/postprocessing" + + def transpose(self, image: torch.Tensor, method: str): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + methods = { + "Flip horizontal": Image.Transpose.FLIP_LEFT_RIGHT, + "Flip vertical": Image.Transpose.FLIP_TOP_BOTTOM, + "Rotate 90°": Image.Transpose.ROTATE_90, + "Rotate 180°": Image.Transpose.ROTATE_180, + "Rotate 270°": Image.Transpose.ROTATE_270, + "Transpose": Image.Transpose.TRANSPOSE, + "Transverse": Image.Transpose.TRANSVERSE, + } + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + + transposed_image = pil_image.transpose(methods[method]) + + transposed_array = torch.tensor(np.array(transposed_image.convert("RGB"))).float() / 255 + result[b] = transposed_array + + return (result,) + +class Rotate: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "angle": ("FLOAT", { + "default": 0, + "min": 0, + "max": 360, + "step": 0.1 + }), + "resample": ([ + "Nearest Neighbor", + "Bilinear", + "Bicubic", + ],), + "fill_color": ("STRING", {"default": "#000000"}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "rotate" + + CATEGORY = "image/postprocessing" + + def rotate(self, image: torch.Tensor, angle: int, resample: str, fill_color: str): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + resamplers = { + "Nearest Neighbor": Image.Resampling.NEAREST, + "Bilinear": Image.Resampling.BILINEAR, + "Bicubic": Image.Resampling.BICUBIC, + } + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + + fill_color = fill_color or "#000000" + + def parse_palette(color_str): + if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap: + return ImageColor.getrgb(color_str) + + color_rgb = re.match(r'^\(?(\d{1,3}),(\d{1,3}),(\d{1,3})\)?$', color_str) + if color_rgb and int(color_rgb.group(1)) <= 255 and int(color_rgb.group(2)) <= 255 and int(color_rgb.group(3)) <= 255: + return tuple(map(int, re.findall(r'\d{1,3}', color_str))) + else: + raise ValueError(f"Invalid color format: {color_str}") + + color = fill_color.replace(" ", "") + color = parse_palette(color) + rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=False, fillcolor=color) + + rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 + result[b] = rotated_array + + return (result,) + NODE_CLASS_MAPPINGS = { "ImageBlend": Blend, "ImageBlur": Blur, "ImageQuantize": Quantize, "ImageSharpen": Sharpen, + "ImageTranspose": Transpose, + "ImageRotate": Rotate, }