diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 89618de10..1d554422b 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -261,7 +261,7 @@ class Rotate: "image": ("IMAGE",), "angle": ("FLOAT", { "default": 0, - "min": 0, + "min": -360, "max": 360, "step": 0.1 }), @@ -271,6 +271,8 @@ class Rotate: "Bicubic", ],), "expand": (["disabled", "enabled"],), + "center": ("STRING", {"default": None}), + "translate": ("STRING", {"default": None}), "fill_color": ("STRING", {"default": "#000000"}), }, } @@ -280,7 +282,7 @@ class Rotate: CATEGORY = "image/postprocessing" - def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, fill_color: str): + def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, center: str, translate: str, fill_color: str): batch_size, _, _, _ = image.shape resamplers = { @@ -295,7 +297,7 @@ class Rotate: expand = True if expand == "enabled" else False fill_color = fill_color or "#000000" - + def parse_palette(color_str): if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap: return ImageColor.getrgb(color_str) @@ -306,11 +308,20 @@ class Rotate: else: raise ValueError(f"Invalid color format: {color_str}") + def parse_coord(coord_str): + if re.match(r'^\s*-?\d+\s*,\s*-?\d+\s*$', coord_str): + return tuple(map(int, re.findall(r'-?\d+', coord_str))) + else: + raise ValueError(f"Invalid coordinates: {coord_str}") + + center = parse_coord(center) if center else tuple(map(lambda x: x / 2, pil_image.size)) + translate = parse_coord(translate) if translate else (0, 0) + color = fill_color.replace(" ", "") color = parse_palette(color) - rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, fillcolor=color) + rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color) height, width = rotated_image.size - result = torch.zeros(batch_size, height, width, 3) + result = torch.zeros(batch_size, width, height, 3) rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 result[0] = rotated_array