Add center, translate options to rotate

Center and translate take X,Y coordinates
Change min to -360
This commit is contained in:
missionfloyd 2023-04-11 21:42:19 -06:00 committed by GitHub
parent fcc561261d
commit f5474109a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -261,7 +261,7 @@ class Rotate:
"image": ("IMAGE",), "image": ("IMAGE",),
"angle": ("FLOAT", { "angle": ("FLOAT", {
"default": 0, "default": 0,
"min": 0, "min": -360,
"max": 360, "max": 360,
"step": 0.1 "step": 0.1
}), }),
@ -271,6 +271,8 @@ class Rotate:
"Bicubic", "Bicubic",
],), ],),
"expand": (["disabled", "enabled"],), "expand": (["disabled", "enabled"],),
"center": ("STRING", {"default": None}),
"translate": ("STRING", {"default": None}),
"fill_color": ("STRING", {"default": "#000000"}), "fill_color": ("STRING", {"default": "#000000"}),
}, },
} }
@ -280,7 +282,7 @@ class Rotate:
CATEGORY = "image/postprocessing" CATEGORY = "image/postprocessing"
def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, fill_color: str): def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, center: str, translate: str, fill_color: str):
batch_size, _, _, _ = image.shape batch_size, _, _, _ = image.shape
resamplers = { resamplers = {
@ -306,11 +308,20 @@ class Rotate:
else: else:
raise ValueError(f"Invalid color format: {color_str}") raise ValueError(f"Invalid color format: {color_str}")
def parse_coord(coord_str):
if re.match(r'^\s*-?\d+\s*,\s*-?\d+\s*$', coord_str):
return tuple(map(int, re.findall(r'-?\d+', coord_str)))
else:
raise ValueError(f"Invalid coordinates: {coord_str}")
center = parse_coord(center) if center else tuple(map(lambda x: x / 2, pil_image.size))
translate = parse_coord(translate) if translate else (0, 0)
color = fill_color.replace(" ", "") color = fill_color.replace(" ", "")
color = parse_palette(color) color = parse_palette(color)
rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, fillcolor=color) rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color)
height, width = rotated_image.size height, width = rotated_image.size
result = torch.zeros(batch_size, height, width, 3) result = torch.zeros(batch_size, width, height, 3)
rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255
result[0] = rotated_array result[0] = rotated_array