From fcc561261d8f01395397c9e46caa54264d2e30f1 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 11 Apr 2023 20:17:00 -0600 Subject: [PATCH] Use torch for extract, split, merge --- comfy_extras/nodes_post_processing.py | 44 +++++++-------------------- 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ec3395060..89618de10 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -326,9 +326,9 @@ class GetChannel: "required": { "image": ("IMAGE",), "channel": ([ - "Red", - "Green", - "Blue", + "red", + "green", + "blue", ],), }, } @@ -341,16 +341,11 @@ class GetChannel: def getchannel(self, image: torch.Tensor, channel: str): batch_size, height, width, _ = image.shape result = torch.zeros_like(image) + channels = ["red", "green", "blue"] for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') - - output_image = pil_image.getchannel(channel[0]) - - output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255 - result[b] = output_array + channel_out = image[b, :, :, channels.index(channel)] + result[b] = channel_out[:, :, None].expand(-1, -1, 3) return (result,) @@ -380,16 +375,8 @@ class Split: result_b = torch.zeros_like(image) for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') - - output_r, output_g, output_b = pil_image.split() - - output_array_r = torch.tensor(np.array(output_r.convert("RGB"))).float() / 255 - output_array_g = torch.tensor(np.array(output_g.convert("RGB"))).float() / 255 - output_array_b = torch.tensor(np.array(output_b.convert("RGB"))).float() / 255 - result_r[b], result_g[b], result_b[b] = output_array_r, output_array_g, output_array_b + channels = torch.chunk(image[b], 3, 2) + result_r[b], result_g[b], result_b[b] = [x.expand(-1, -1, 3) for x in channels] return (result_r, result_g, result_b) @@ -415,19 +402,10 @@ class Merge: def merge(self, red: torch.Tensor, green: torch.Tensor, blue: torch.Tensor): batch_size, height, width, _ = red.shape result = torch.zeros_like(red) - + images = [red, green, blue] for b in range(batch_size): - img_r = (red[b] * 255).to(torch.uint8).numpy() - img_g = (green[b] * 255).to(torch.uint8).numpy() - img_b = (blue[b] * 255).to(torch.uint8).numpy() - pil_image_r = Image.fromarray(img_r, mode='RGB').convert("L") - pil_image_g = Image.fromarray(img_g, mode='RGB').convert("L") - pil_image_b = Image.fromarray(img_b, mode='RGB').convert("L") - - output_image = Image.merge("RGB", (pil_image_r, pil_image_g, pil_image_b)) - - output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255 - result[b] = output_array + for i in range(3): + result[b, :, :, i] = images[i][b, :, :, 0] return (result,)