diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 3bc9fccb3..7f6ce78f8 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -185,6 +185,36 @@ class SplitImageWithAlpha(io.ComfyNode): return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas)) +class SplitImageChannels(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitImageChannels", + search_aliases=["extract alpha", "extract channels", "separate transparency", "split channels", "remove alpha"], + display_name="Split Image Channels", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(display_name="red"), + io.Image.Output(display_name="green"), + io.Image.Output(display_name="blue"), + io.Mask.Output(display_name="alpha") + ], + ) + + @classmethod + def execute(cls, image: torch.Tensor) -> io.NodeOutput: + images = [i[:,:,:3] for i in image] + stacked = torch.stack(images) + reds = stacked[:, :, :, 0:1].repeat(1, 1, 1, 3) + greens = stacked[:, :, :, 1:2].repeat(1, 1, 1, 3) + blues = stacked[:, :, :, 2:3].repeat(1, 1, 1, 3) + alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] + return io.NodeOutput(reds, greens, blues, 1.0 - torch.stack(alphas)) + + class JoinImageWithAlpha(io.ComfyNode): @classmethod def define_schema(cls): @@ -217,6 +247,7 @@ class CompositingExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ PorterDuffImageComposite, + SplitImageChannels, SplitImageWithAlpha, JoinImageWithAlpha, ]