Fix rotate node with batch_size > 1

This commit is contained in:
missionfloyd 2023-04-27 20:45:58 -06:00
parent 7b64c38d52
commit 03dabd2f56

View File

@ -290,13 +290,6 @@ class Rotate:
"bilinear": Image.Resampling.BILINEAR,
"bicubic": Image.Resampling.BICUBIC,
}
tensor_image = image[0]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
expand = True if expand == "enabled" else False
fill_color = fill_color or "#000000"
def parse_palette(color_str):
if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap:
@ -308,16 +301,24 @@ class Rotate:
else:
raise ValueError(f"Invalid color format: {color_str}")
expand = True if expand == "enabled" else False
fill_color = fill_color or "#000000"
center = (width / 2, height / 2) if center_of_image == "enabled" else (center_x, center_y)
translate = (translate_x, translate_y)
color = fill_color.replace(" ", "")
color = parse_palette(color)
rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color)
result_width, result_height = rotated_image.size
result_width, result_height = Image.new("RGB", (width, height)).rotate(angle=angle, expand=expand, center=center, translate=translate).size
result = torch.zeros(batch_size, result_height, result_width, 3)
rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255
result[0] = rotated_array
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color)
rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255
result[b] = rotated_array
return (result,)