diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 293df28d7..01b3bf546 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,6 +6,7 @@ import re import comfy.utils +MAX_RESOLUTION=8192 class Blend: def __init__(self): @@ -271,8 +272,11 @@ class Rotate: "Bicubic", ],), "expand": (["disabled", "enabled"],), - "center": ("STRING", {"default": None}), - "translate": ("STRING", {"default": None}), + "center_x": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), + "center_y": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), + "center_of_image": (["disabled", "enabled"],), + "translate_x": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), + "translate_y": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}), "fill_color": ("STRING", {"default": "#000000"}), }, } @@ -282,8 +286,8 @@ class Rotate: CATEGORY = "image/postprocessing" - def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, center: str, translate: str, fill_color: str): - batch_size, _, _, _ = image.shape + def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, center_x: int, center_y: int, center_of_image: str, translate_x: int, translate_y: int, fill_color: str): + batch_size, height, width, _ = image.shape resamplers = { "Nearest Neighbor": Image.Resampling.NEAREST, @@ -308,20 +312,16 @@ class Rotate: else: raise ValueError(f"Invalid color format: {color_str}") - def parse_coord(coord_str): - if re.match(r'^\s*-?\d+\s*,\s*-?\d+\s*$', coord_str): - return tuple(map(int, re.findall(r'-?\d+', coord_str))) - else: - raise ValueError(f"Invalid coordinates: {coord_str}") - - center = parse_coord(center) if center else tuple(map(lambda x: x / 2, pil_image.size)) - translate = parse_coord(translate) if translate else (0, 0) + center = (center_x, center_y) if center_of_image == "disabled" else (width / 2, height / 2) + print(center_of_image) + print(center) + translate = (translate_x, translate_y) color = fill_color.replace(" ", "") color = parse_palette(color) rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color) - width, height = rotated_image.size - result = torch.zeros(batch_size, height, width, 3) + result_width, result_height = rotated_image.size + result = torch.zeros(batch_size, result_height, result_width, 3) rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 result[0] = rotated_array