Add 10-bit video support (#14452)

Create Video gets a bit_depth option (8-bit/10-bit); the selected depth is carried by the video and applied when it gets encoded. Save Video and Video Slice now keep the source bit depth instead of always quantizing to 8-bit, so 10-bit videos stay 10-bit. 10-bit uses h264 with the yuv420p10le pixel format,so there's no new codec or container.

Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
Alexander Piskun 2026-06-13 16:05:25 +03:00 committed by GitHub
parent 7277d99d3a
commit fe54b5e955
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 169 additions and 12 deletions

View File

@ -27,10 +27,13 @@ class VideoInput(ABC):
path: Union[str, IO[bytes]], path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None,
bit_depth: int | None = None,
): ):
""" """
Abstract method to save the video input to a file. Abstract method to save the video input to a file.
bit_depth selects the encoded bit depth; None keeps the video's native depth.
""" """
pass pass
@ -83,6 +86,14 @@ class VideoInput(ABC):
components = self.get_components() components = self.get_components()
return components.images.shape[2], components.images.shape[1] return components.images.shape[2], components.images.shape[1]
def get_bit_depth(self) -> int:
"""
Returns the bit depth of the video (e.g. 8 or 10).
Default implementation returns 8; subclasses report their real depth.
"""
return 8
def get_duration(self) -> float: def get_duration(self) -> float:
""" """
Returns the duration of the video in seconds. Returns the duration of the video in seconds.

View File

@ -52,6 +52,12 @@ def get_open_write_kwargs(
return open_kwargs return open_kwargs
def video_stream_bit_depth(stream) -> int:
if stream is None or stream.format is None or not stream.format.components:
return 8
return max(component.bits for component in stream.format.components)
class VideoFromFile(VideoInput): class VideoFromFile(VideoInput):
""" """
Class representing video input from a file. Class representing video input from a file.
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self.__file}'")
def get_bit_depth(self) -> int:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode="r") as container:
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
return video_stream_bit_depth(video_stream)
def get_duration(self) -> float: def get_duration(self) -> float:
""" """
Returns the duration of the video in seconds. Returns the duration of the video in seconds.
@ -377,25 +390,32 @@ class VideoFromFile(VideoInput):
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
bit_depth: int | None = None,
): ):
if isinstance(self.__file, io.BytesIO): if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container: with av.open(self.__file, mode='r') as container:
container_format = container.format.name container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
video_encoding = video_stream.codec.name if video_stream is not None else None
source_bit_depth = video_stream_bit_depth(video_stream)
reuse_streams = True reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","): if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False reuse_streams = False
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
reuse_streams = False
if self.__start_time or self.__duration: if self.__start_time or self.__duration:
reuse_streams = False reuse_streams = False
if not reuse_streams: if not reuse_streams:
if bit_depth is None:
bit_depth = source_bit_depth
components = self.get_components_internal(container) components = self.get_components_internal(container)
video = VideoFromComponents(components) video = VideoFromComponents(components)
return video.save_to( return video.save_to(
path, format=format, codec=codec, metadata=metadata path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
) )
streams = container.streams streams = container.streams
@ -451,8 +471,10 @@ class VideoFromComponents(VideoInput):
Class representing video input from tensors. Class representing video input from tensors.
""" """
def __init__(self, components: VideoComponents): def __init__(self, components: VideoComponents, bit_depth: int = 8):
self.__components = components self.__components = components
# Tensor components have no inherent bit depth; this is the depth used when encoding.
self.__bit_depth = bit_depth
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
return VideoComponents( return VideoComponents(
@ -461,18 +483,26 @@ class VideoFromComponents(VideoInput):
frame_rate=self.__components.frame_rate, frame_rate=self.__components.frame_rate,
) )
def get_bit_depth(self) -> int:
return self.__bit_depth
def save_to( def save_to(
self, self,
path: str, path: str,
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
bit_depth: int | None = None,
): ):
"""Save the video to a file path or BytesIO buffer.""" """Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4: if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now") raise ValueError("Only H264 codec is supported for now")
# None means "use the depth this video was created with" (CreateVideo's choice).
if bit_depth is None:
bit_depth = self.__bit_depth
is_10bit = bit_depth >= 10
extra_kwargs = {} extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value extra_kwargs["format"] = format.value
@ -488,10 +518,11 @@ class VideoFromComponents(VideoInput):
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream # Create a video stream
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
video_stream = output.add_stream('h264', rate=frame_rate) video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2] video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1] video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p' video_stream.pix_fmt = pix_fmt
# Create an audio stream # Create an audio stream
audio_sample_rate = 1 audio_sample_rate = 1
@ -505,9 +536,14 @@ class VideoFromComponents(VideoInput):
# Encode video # Encode video
for i, frame in enumerate(self.__components.images): for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) if is_10bit:
frame = av.VideoFrame.from_ndarray(img, format='rgb24') # 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
else:
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format=pix_fmt)
packet = video_stream.encode(frame) packet = video_stream.encode(frame)
output.mux(packet) output.mux(packet)

View File

@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode):
io.Image.Input("images", tooltip="The images to create a video from."), io.Image.Input("images", tooltip="The images to create a video from."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
io.Int.Input(
"bit_depth",
min=8,
max=10,
default=8,
step=2,
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
" banding, but some players and downstream nodes may not support it.",
optional=True,
display_mode=io.NumberDisplay.number,
),
], ],
outputs=[ outputs=[
io.Video.Output(), io.Video.Output(),
@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: def execute(
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
) -> io.NodeOutput:
return io.NodeOutput( return io.NodeOutput(
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) InputImpl.VideoFromComponents(
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
bit_depth=bit_depth,
)
) )
class GetVideoComponents(io.ComfyNode): class GetVideoComponents(io.ComfyNode):
@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode):
search_aliases=["extract frames", "split video", "video to images", "demux"], search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components", display_name="Get Video Components",
category="video", category="video",
description="Extracts all components from a video: frames, audio, and framerate.", description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
inputs=[ inputs=[
io.Video.Input("video", tooltip="The video to extract components from."), io.Video.Input("video", tooltip="The video to extract components from."),
], ],
@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode):
io.Image.Output(display_name="images"), io.Image.Output(display_name="images"),
io.Audio.Output(display_name="audio"), io.Audio.Output(display_name="audio"),
io.Float.Output(display_name="fps"), io.Float.Output(display_name="fps"),
io.Int.Output(display_name="bit_depth"),
], ],
) )
@classmethod @classmethod
def execute(cls, video: Input.Video) -> io.NodeOutput: def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components() components = video.get_components()
return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
class LoadVideo(io.ComfyNode): class LoadVideo(io.ComfyNode):

View File

@ -0,0 +1,93 @@
import pytest
import torch
import av
import numpy as np
from fractions import Fraction
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.latest._util.video_types import VideoComponents
@pytest.fixture(scope="module")
def gradient_components():
"""Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth"""
width, height, frames = 64, 64, 3
ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3)
return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30))
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
"""8-bit h264 mp4 (Create Video default)"""
path = str(tmp_path_factory.mktemp("video") / "src8.mp4")
VideoFromComponents(gradient_components).save_to(path)
return path
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
"""10-bit h264 mp4 (Create Video with bit_depth=10)"""
path = str(tmp_path_factory.mktemp("video") / "src10.mp4")
VideoFromComponents(gradient_components, bit_depth=10).save_to(path)
return path
def probe(path):
"""(codec, pix_fmt, bit_depth) of the first video stream"""
with av.open(path) as container:
stream = container.streams.video[0]
return (stream.codec.name, stream.format.name, max(c.bits for c in stream.format.components))
def decoded_levels(path):
"""Unique tonal levels in the first decoded frame (banding measure)"""
with av.open(path) as container:
frame = next(container.decode(container.streams.video[0]))
return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0]))
def video_packet_bytes(path):
"""Raw video packet payloads; identical to the source's only for a true remux"""
with av.open(path) as container:
return [bytes(p) for p in container.demux(container.streams.video[0]) if p.size]
def test_create_video_bit_depth(src8, src10):
"""Create Video's bit_depth picks the encoded depth (default 8-bit); 10-bit reduces banding"""
assert probe(src8) == ("h264", "yuv420p", 8)
assert probe(src10) == ("h264", "yuv420p10le", 10)
assert decoded_levels(src10) > 2 * decoded_levels(src8)
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
"""Save Video (no bit_depth = auto) stream-copies the source, preserving its depth byte-for-byte"""
for name, src in [("p8", src8), ("p10", src10)]:
path = str(tmp_path / f"{name}.mp4")
VideoFromFile(src).save_to(path)
assert probe(path) == probe(src)
assert video_packet_bytes(path) == video_packet_bytes(src)
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
"""An explicit bit_depth different from the source forces a re-encode to that depth"""
down = str(tmp_path / "down8.mp4")
VideoFromFile(src10).save_to(down, bit_depth=8)
assert probe(down) == ("h264", "yuv420p", 8)
up = str(tmp_path / "up10.mp4")
VideoFromFile(src8).save_to(up, bit_depth=10)
assert probe(up) == ("h264", "yuv420p10le", 10)
def test_trim_keeps_source_depth(src10, tmp_path):
"""Video Slice re-encodes (trim) but preserves the source's 10-bit depth"""
path = str(tmp_path / "trim.mp4")
VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path)
assert probe(path) == ("h264", "yuv420p10le", 10)
def test_get_bit_depth(gradient_components, src8, src10):
"""get_bit_depth reports a video's depth (backs the Get Video Components output)"""
assert VideoFromFile(src8).get_bit_depth() == 8
assert VideoFromFile(src10).get_bit_depth() == 10
assert VideoFromComponents(gradient_components, bit_depth=10).get_bit_depth() == 10
assert VideoFromComponents(gradient_components).get_bit_depth() == 8