diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 69b0233af..ce2c9e9be 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -1,13 +1,16 @@ +import av +import torch +from av.codec import CodecContext from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bria import ( BriaEditImageRequest, + BriaImageEditResponse, BriaRemoveBackgroundRequest, BriaRemoveBackgroundResponse, BriaRemoveVideoBackgroundRequest, BriaRemoveVideoBackgroundResponse, - BriaImageEditResponse, BriaStatusResponse, InputModerationSettings, ) @@ -316,6 +319,96 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) +def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]: + """Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask. + + VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames + are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W] + following the Load Image convention (1 = transparent) for compositing or Save WEBM. + """ + rgb_frames: list[torch.Tensor] = [] + alpha_frames: list[torch.Tensor] = [] + with av.open(video.get_stream_source(), mode="r") as container: + stream = container.streams.video[0] + decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None + for packet in container.demux(stream): + for frame in (decoder.decode(packet) if decoder is not None else packet.decode()): + rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0 + rgb_frames.append(rgba[..., :3]) + alpha_frames.append(rgba[..., 3]) + images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3) + mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64)) + return images, mask + + +class BriaTransparentVideoBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaTransparentVideoBackground", + display_name="Bria Remove Video Background (Transparent)", + category="partner/video/Bria", + description="Remove the background from a video using Bria and return the cut-out frames " + "plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to " + "write a transparent video.", + inputs=[ + IO.Video.Input("video"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(display_name="images"), + IO.Mask.Output(display_name="mask"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + seed: int, + ) -> IO.NodeOutput: + validate_video_duration(video, max_duration=60.0) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"), + data=BriaRemoveVideoBackgroundRequest( + video=await upload_video_to_comfyapi(cls, video), + background_color="Transparent", + output_container_and_codec="webm_vp9", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + video_out = await download_url_to_video_output(response.result.video_url) + images, mask = _video_to_images_and_mask(video_out) + return IO.NodeOutput(images, mask) + + class BriaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -323,6 +416,7 @@ class BriaExtension(ComfyExtension): BriaImageEditNode, BriaRemoveImageBackground, BriaRemoveVideoBackground, + BriaTransparentVideoBackground, ] diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index ae1d826d5..6f6c416a6 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -19,7 +19,7 @@ class SaveWEBM(io.ComfyNode): category="video", is_experimental=True, inputs=[ - io.Image.Input("images"), + io.Image.Input("images", tooltip="RGBA images are saved with their alpha channel as transparency (vp9 codec only)."), io.String.Input("filename_prefix", default="ComfyUI"), io.Combo.Input("codec", options=["vp9", "av1"]), io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), @@ -45,18 +45,25 @@ class SaveWEBM(io.ComfyNode): for x in cls.hidden.extra_pnginfo: container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) + # Save transparency when the images carry an alpha channel (RGBA) and the codec supports it. + # vp9 -> yuva420p; other codecs have no usable alpha path, so the alpha is ignored. + save_alpha = images.shape[-1] == 4 and codec == "vp9" + codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) stream.width = images.shape[-2] stream.height = images.shape[-3] - stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p" + stream.pix_fmt = "yuva420p" if save_alpha else ("yuv420p10le" if codec == "av1" else "yuv420p") stream.bit_rate = 0 stream.options = {'crf': str(crf)} if codec == "av1": stream.options["preset"] = "6" for frame in images: - frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") + if save_alpha: + frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :4] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgba") + else: + frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") for packet in stream.encode(frame): container.mux(packet) container.mux(stream.encode())