diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index b8585b53f..b58267c12 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -320,6 +320,82 @@ class Rotate: return (result,) +class GetChannel: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": ([ + "Red", + "Green", + "Blue", + ],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "getchannel" + + CATEGORY = "image/postprocessing" + + def getchannel(self, image: torch.Tensor, channel: str): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + + output_image = pil_image.getchannel(channel[0]) + + output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255 + result[b] = output_array + + return (result,) + +class Split: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") + FUNCTION = "split" + + CATEGORY = "image/postprocessing" + + def split(self, image: torch.Tensor): + batch_size, height, width, _ = image.shape + result_r = torch.zeros_like(image) + result_g = torch.zeros_like(image) + result_b = torch.zeros_like(image) + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + + output_r, output_g, output_b = pil_image.split() + + output_array_r = torch.tensor(np.array(output_r.convert("RGB"))).float() / 255 + output_array_g = torch.tensor(np.array(output_g.convert("RGB"))).float() / 255 + output_array_b = torch.tensor(np.array(output_b.convert("RGB"))).float() / 255 + result_r[b], result_g[b], result_b[b] = output_array_r, output_array_g, output_array_b + + return (result_r, result_g, result_b) + + NODE_CLASS_MAPPINGS = { "ImageBlend": Blend, "ImageBlur": Blur, @@ -327,4 +403,6 @@ NODE_CLASS_MAPPINGS = { "ImageSharpen": Sharpen, "ImageTranspose": Transpose, "ImageRotate": Rotate, + "ImageGetChannel": GetChannel, + "ImageSplit": Split, }