Support loading the alpha channel of videos. (#13564)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

Not exposed in nodes yet.
This commit is contained in:
comfyanonymous 2026-04-25 18:02:58 -07:00 committed by GitHub
parent 5e3f15a830
commit df22bcd5e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 8 deletions

View File

@ -240,19 +240,34 @@ class VideoFromFile(VideoInput):
start_time = self.__start_time
# Get video frames
frames = []
alphas = None
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
image_format = 'gbrpf32le'
for frame in container.decode(video_stream):
if alphas is None:
for comp in frame.format.components:
if comp.is_alpha:
alphas = []
image_format = 'gbrapf32le'
break
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
img = frame.to_ndarray(format='gbrpf32le') # shape: (H, W, 3)
img = torch.from_numpy(img)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
if alphas is None:
frames.append(torch.from_numpy(img))
else:
frames.append(torch.from_numpy(img[..., :-1]))
alphas.append(torch.from_numpy(img[..., -1:]))
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
if alphas is not None:
alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
# Get frame rate
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
@ -295,7 +310,7 @@ class VideoFromFile(VideoInput):
})
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from .._input import ImageInput, AudioInput
from .._input import ImageInput, AudioInput, MaskInput
class VideoCodec(str, Enum):
AUTO = "auto"
@ -48,5 +48,4 @@ class VideoComponents:
frame_rate: Fraction
audio: Optional[AudioInput] = None
metadata: Optional[dict] = None
alpha: Optional[MaskInput] = None