From df22bcd5e192ce0b1ae09eaf2e423d0a12cf6638 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:02:58 -0700 Subject: [PATCH] Support loading the alpha channel of videos. (#13564) Not exposed in nodes yet. --- comfy_api/latest/_input_impl/video_types.py | 25 ++++++++++++++++----- comfy_api/latest/_util/video_types.py | 5 ++--- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index bd8090635..eb4d3701d 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -240,19 +240,34 @@ class VideoFromFile(VideoInput): start_time = self.__start_time # Get video frames frames = [] + alphas = None start_pts = int(start_time / video_stream.time_base) end_pts = int((start_time + self.__duration) / video_stream.time_base) container.seek(start_pts, stream=video_stream) + image_format = 'gbrpf32le' for frame in container.decode(video_stream): + if alphas is None: + for comp in frame.format.components: + if comp.is_alpha: + alphas = [] + image_format = 'gbrapf32le' + break + if frame.pts < start_pts: continue if self.__duration and frame.pts >= end_pts: break - img = frame.to_ndarray(format='gbrpf32le') # shape: (H, W, 3) - img = torch.from_numpy(img) - frames.append(img) - images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) + img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) + if alphas is None: + frames.append(torch.from_numpy(img)) + else: + frames.append(torch.from_numpy(img[..., :-1])) + alphas.append(torch.from_numpy(img[..., -1:])) + + images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3) + if alphas is not None: + alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1) # Get frame rate frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) @@ -295,7 +310,7 @@ class VideoFromFile(VideoInput): }) metadata = container.metadata - return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) + return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata) def get_components(self) -> VideoComponents: if isinstance(self.__file, io.BytesIO): diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index fd3b5a510..c92477f08 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from .._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput, MaskInput class VideoCodec(str, Enum): AUTO = "auto" @@ -48,5 +48,4 @@ class VideoComponents: frame_rate: Fraction audio: Optional[AudioInput] = None metadata: Optional[dict] = None - - + alpha: Optional[MaskInput] = None