From bc54b69c5986002aa3b393181a28ae75c29fd6b0 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 10 Apr 2023 20:23:15 -0600 Subject: [PATCH] Change transpose to torch --- comfy_extras/nodes_post_processing.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 9e7e8b034..817069615 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -236,24 +236,17 @@ class Transpose: result = torch.permute(result, (0, 2, 1, 3)) methods = { - "Flip horizontal": Image.Transpose.FLIP_LEFT_RIGHT, - "Flip vertical": Image.Transpose.FLIP_TOP_BOTTOM, - "Rotate 90°": Image.Transpose.ROTATE_90, - "Rotate 180°": Image.Transpose.ROTATE_180, - "Rotate 270°": Image.Transpose.ROTATE_270, - "Transpose": Image.Transpose.TRANSPOSE, - "Transverse": Image.Transpose.TRANSVERSE, + "Flip horizontal": (lambda x: torch.fliplr(x)), + "Flip vertical": (lambda x: torch.flipud(x)), + "Rotate 90°": (lambda x: torch.rot90(x)), + "Rotate 180°": (lambda x: torch.rot90(x, 2)), + "Rotate 270°": (lambda x: torch.rot90(x, 3)), + "Transpose": (lambda x: torch.transpose(x, 0, 1)), + "Transverse": (lambda x: torch.rot90(torch.transpose(x, 0, 1), 2)), } for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') - - transposed_image = pil_image.transpose(methods[method]) - - transposed_array = torch.tensor(np.array(transposed_image.convert("RGB"))).float() / 255 - result[b] = transposed_array + result[b] = methods[method](image[b]) return (result,)