diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 817069615..ec3395060 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -270,6 +270,7 @@ class Rotate: "Bilinear", "Bicubic", ],), + "expand": (["disabled", "enabled"],), "fill_color": ("STRING", {"default": "#000000"}), }, } @@ -279,9 +280,8 @@ class Rotate: CATEGORY = "image/postprocessing" - def rotate(self, image: torch.Tensor, angle: int, resample: str, fill_color: str): - batch_size, height, width, _ = image.shape - result = torch.zeros_like(image) + def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, fill_color: str): + batch_size, _, _, _ = image.shape resamplers = { "Nearest Neighbor": Image.Resampling.NEAREST, @@ -289,29 +289,30 @@ class Rotate: "Bicubic": Image.Resampling.BICUBIC, } - for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') - - fill_color = fill_color or "#000000" + tensor_image = image[0] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + + expand = True if expand == "enabled" else False + fill_color = fill_color or "#000000" - def parse_palette(color_str): - if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap: - return ImageColor.getrgb(color_str) + def parse_palette(color_str): + if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap: + return ImageColor.getrgb(color_str) - color_rgb = re.match(r'^\(?(\d{1,3}),(\d{1,3}),(\d{1,3})\)?$', color_str) - if color_rgb and int(color_rgb.group(1)) <= 255 and int(color_rgb.group(2)) <= 255 and int(color_rgb.group(3)) <= 255: - return tuple(map(int, re.findall(r'\d{1,3}', color_str))) - else: - raise ValueError(f"Invalid color format: {color_str}") + color_rgb = re.match(r'^\(?(\d{1,3}),(\d{1,3}),(\d{1,3})\)?$', color_str) + if color_rgb and int(color_rgb.group(1)) <= 255 and int(color_rgb.group(2)) <= 255 and int(color_rgb.group(3)) <= 255: + return tuple(map(int, re.findall(r'\d{1,3}', color_str))) + else: + raise ValueError(f"Invalid color format: {color_str}") - color = fill_color.replace(" ", "") - color = parse_palette(color) - rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=False, fillcolor=color) - - rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 - result[b] = rotated_array + color = fill_color.replace(" ", "") + color = parse_palette(color) + rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, fillcolor=color) + height, width = rotated_image.size + result = torch.zeros(batch_size, height, width, 3) + rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 + result[0] = rotated_array return (result,)