Change transpose to torch

This commit is contained in:
missionfloyd 2023-04-10 20:23:15 -06:00 committed by GitHub
parent 70564aebb6
commit bc54b69c59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -236,24 +236,17 @@ class Transpose:
result = torch.permute(result, (0, 2, 1, 3))
methods = {
"Flip horizontal": Image.Transpose.FLIP_LEFT_RIGHT,
"Flip vertical": Image.Transpose.FLIP_TOP_BOTTOM,
"Rotate 90°": Image.Transpose.ROTATE_90,
"Rotate 180°": Image.Transpose.ROTATE_180,
"Rotate 270°": Image.Transpose.ROTATE_270,
"Transpose": Image.Transpose.TRANSPOSE,
"Transverse": Image.Transpose.TRANSVERSE,
"Flip horizontal": (lambda x: torch.fliplr(x)),
"Flip vertical": (lambda x: torch.flipud(x)),
"Rotate 90°": (lambda x: torch.rot90(x)),
"Rotate 180°": (lambda x: torch.rot90(x, 2)),
"Rotate 270°": (lambda x: torch.rot90(x, 3)),
"Transpose": (lambda x: torch.transpose(x, 0, 1)),
"Transverse": (lambda x: torch.rot90(torch.transpose(x, 0, 1), 2)),
}
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
transposed_image = pil_image.transpose(methods[method])
transposed_array = torch.tensor(np.array(transposed_image.convert("RGB"))).float() / 255
result[b] = transposed_array
result[b] = methods[method](image[b])
return (result,)