ComfyUI/tests-unit/comfy_api_test/video_bit_depth_test.py
bigcat88 87790af8a7
Add 10-bit video support to Save Video
Save Video gets a bit_depth widget (auto/8-bit/10-bit). 'auto' preserves
the source file's bit depth when re-encoding; 10-bit encodes h264
yuv420p10le from 16-bit RGB frames so float-precision sources keep their
gradients instead of being quantized to 8-bit.

Video inputs can declare 10-bit support via Video.Input(accepts={"depth": 10}).
At input binding, videos bound to inputs without the declaration are
replaced with a copy whose saved files default to 8-bit, so existing nodes keep producing 8-bit files no matter the
source depth. SaveVideo and VideoSlice declare support, so trimming a
10-bit video and saving it keeps 10-bit.

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-06-13 09:51:21 +03:00

239 lines
9.7 KiB
Python

import pytest
import torch
import av
import logging
import numpy as np
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.latest._input_impl.video_types import apply_video_input_accepts
from comfy_api.util.video_types import VideoComponents
from comfy_api.latest._util.video_types import VideoBitDepth
# Input-info dict for a video input that declares 10-bit support; passed to
# apply_video_input_accepts in test_accepts_binding_policy.
DECLARED = {"accepts": {"depth": 10}}
@pytest.fixture(scope="module")
def gradient_components():
    """Build video components holding a shallow horizontal ramp (0.25 to 0.30).

    The ramp is narrow enough that it needs more than 8 bits to stay smooth,
    which makes it a useful banding probe for the bit-depth tests.
    """
    frames, height, width = 3, 64, 64
    # One value per column, broadcast over frames, rows and the 3 channels.
    columns = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1)
    images = columns.expand(frames, height, width, 3).contiguous()
    return VideoComponents(images=images, frame_rate=Fraction(30))
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Save the ramp with default settings and return the mp4 path.

    The tests rely on this file being an 8-bit h264 source.
    """
    target = tmp_path_factory.mktemp("video") / "src8.mp4"
    path = str(target)
    VideoFromComponents(gradient_components).save_to(path)
    return path
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Save the ramp with an explicit 10-bit request and return the mp4 path.

    The tests rely on this file being a 10-bit h264 source.
    """
    target = tmp_path_factory.mktemp("video") / "src10.mp4"
    path = str(target)
    VideoFromComponents(gradient_components).save_to(path, bit_depth=VideoBitDepth.BIT_10)
    return path
def probe(path):
    """Inspect the first video stream of the file at ``path``.

    Returns a ``(codec name, pixel format name, bit depth)`` tuple, where the
    bit depth is the largest ``bits`` value among the pixel format's components.
    """
    with av.open(path) as container:
        video = container.streams.video[0]
        codec_name = video.codec.name
        pix_fmt = video.format
        pix_fmt_name = pix_fmt.name
        depth = max(component.bits for component in pix_fmt.components)
        return codec_name, pix_fmt_name, depth
def decoded_levels(path):
    """Count the distinct values in one channel of the first decoded frame.

    The frame is converted to the ``gbrpf32le`` pixel format and index 0 of the
    last array axis is examined; fewer distinct levels means more banding.
    """
    with av.open(path) as container:
        video_stream = container.streams.video[0]
        first_frame = next(container.decode(video_stream))
        pixels = first_frame.to_ndarray(format="gbrpf32le")
        channel = pixels[..., 0]
        return len(np.unique(channel))
def video_packet_bytes(path):
    """Collect the raw payload of every non-empty packet of the first video stream.

    Comparing this list with the source file's list tells a true remux
    (identical payloads) apart from a re-encode.
    """
    payloads = []
    with av.open(path) as container:
        video_stream = container.streams.video[0]
        for packet in container.demux(video_stream):
            # Skip zero-size packets; only real payloads are compared.
            if packet.size:
                payloads.append(bytes(packet))
    return payloads
def test_components_save_bit_depths(src8, src10):
    """A default save is 8-bit h264; a 10-bit save stays h264 and bands far less."""
    default_info = probe(src8)
    assert default_info == ("h264", "yuv420p", 8)

    ten_bit_info = probe(src10)
    assert ten_bit_info == ("h264", "yuv420p10le", 10)

    # The 10-bit file must show more than twice as many distinct levels.
    levels_10 = decoded_levels(src10)
    levels_8 = decoded_levels(src8)
    assert levels_10 > 2 * levels_8
def test_components_unsupported_codec_raises(gradient_components, tmp_path):
    """Saving a component video with codec "vp9" raises ValueError mentioning H264."""
    destination = str(tmp_path / "x.mp4")
    with pytest.raises(ValueError, match="H264"):
        VideoFromComponents(gradient_components).save_to(destination, codec="vp9")
def test_bit_depth_enum():
    """VideoBitDepth exposes the three widget labels and their matching bit counts."""
    labels = VideoBitDepth.as_input()
    assert labels == ["auto", "8-bit", "10-bit"]

    bit_counts = [member.bits() for member in VideoBitDepth]
    assert bit_counts == [None, 8, 10]
def test_10bit_source_remuxes_untouched(src10, tmp_path):
    """Saving a 10-bit file with no cap, or with a cap of 10, leaves the stream untouched."""
    cases = {
        "auto": VideoFromFile(src10),
        "cap10": VideoFromFile(src10).with_bit_depth_cap(10),
    }
    for label, video in cases.items():
        out_path = str(tmp_path / f"{label}.mp4")
        video.save_to(out_path)
        assert probe(out_path) == ("h264", "yuv420p10le", 10)
        # Byte-identical packets prove a remux rather than a re-encode.
        assert video_packet_bytes(out_path) == video_packet_bytes(src10)
def test_8bit_source_remuxes_on_8bit_request(src8, tmp_path):
    """An 8-bit source is not re-encoded by an explicit 8-bit request or by a cap of 8."""

    def save_explicit(destination):
        VideoFromFile(src8).save_to(destination, bit_depth="8-bit")

    def save_capped(destination):
        VideoFromFile(src8).with_bit_depth_cap(8).save_to(destination)

    for label, save in (("explicit", save_explicit), ("capped", save_capped)):
        out_path = str(tmp_path / f"{label}.mp4")
        save(out_path)
        # Byte-identical packets prove the stream was copied, not re-encoded.
        assert video_packet_bytes(out_path) == video_packet_bytes(src8)
def test_trim_keeps_source_depth(src10, tmp_path):
    """Trimming forces a re-encode, and that re-encode keeps the source's 10-bit depth."""
    out_path = str(tmp_path / "trim.mp4")
    source = VideoFromFile(src10)
    trimmed = source.as_trimmed(start_time=0, duration=1 / 30, strict_duration=False)
    trimmed.save_to(out_path)
    assert probe(out_path) == ("h264", "yuv420p10le", 10)
def test_explicit_depth_mismatch_forces_reencode(src8, src10, tmp_path):
    """Requesting a depth other than the source's yields a file at the requested depth.

    Covers both directions: 10-bit source saved as 8-bit, and 8-bit source
    saved as 10-bit.
    """
    cases = (
        ("down8.mp4", src10, VideoBitDepth.BIT_8, ("h264", "yuv420p", 8)),
        ("up10.mp4", src8, VideoBitDepth.BIT_10, ("h264", "yuv420p10le", 10)),
    )
    for filename, source_path, requested, expected in cases:
        out_path = str(tmp_path / filename)
        VideoFromFile(source_path).save_to(out_path, bit_depth=requested)
        assert probe(out_path) == expected
def test_bit_depth_cap(src10, tmp_path):
    """A cap of 8 changes only the default save depth.

    Plain saves and saves after ``as_trimmed`` come out 8-bit, an explicit
    10-bit request still wins, and decoded tensors keep full precision.
    """
    eight_bit = ("h264", "yuv420p", 8)
    capped = VideoFromFile(src10).with_bit_depth_cap(8)

    # Default save honours the cap.
    default_path = str(tmp_path / "capped.mp4")
    capped.save_to(default_path)
    assert probe(default_path) == eight_bit

    # The cap carries over to a trimmed copy.
    trimmed_path = str(tmp_path / "trimmed.mp4")
    capped.as_trimmed(0, 1 / 30, strict_duration=False).save_to(trimmed_path)
    assert probe(trimmed_path) == eight_bit

    # An explicit request overrides the cap.
    explicit_path = str(tmp_path / "explicit10.mp4")
    capped.save_to(explicit_path, bit_depth=VideoBitDepth.BIT_10)
    assert probe(explicit_path) == ("h264", "yuv420p10le", 10)

    # Tensor access is unaffected by the cap.
    images = capped.get_components().images
    assert images.dtype == torch.float32
    first_frame_channel = images[0, :, :, 0]
    distinct_levels = torch.unique(first_frame_channel)
    assert len(distinct_levels) > 30  # ~13 levels if quantized to 8-bit
def test_accepts_binding_policy(gradient_components, src10, tmp_path):
    """Check how apply_video_input_accepts treats videos at input binding.

    Inputs without a 10-bit declaration receive an 8-bit-capped copy of file
    videos, declared inputs receive uncapped videos, and every other kind of
    value is passed through untouched.
    """
    from comfy_api.latest._input import VideoInput as VideoInputABC

    class SubVideo(VideoFromFile):
        pass

    class CustomVideo(VideoInputABC):
        def get_components(self):
            raise NotImplementedError

        def save_to(self, path, format=None, codec=None, metadata=None):
            raise NotImplementedError

        def as_trimmed(self, start_time=None, duration=None, strict_duration=False):
            return self

    ten_bit = ("h264", "yuv420p10le", 10)
    source = VideoFromFile(src10)

    # Undeclared input: a capped copy whose default save is 8-bit.
    (capped,) = apply_video_input_accepts([source], {"tooltip": "x"})
    assert type(capped) is VideoFromFile
    assert capped is not source
    bound_path = str(tmp_path / "bound.mp4")
    capped.save_to(bound_path)
    assert probe(bound_path) == ("h264", "yuv420p", 8)

    # Declared input: the original object is handed over as-is ...
    declared_result = apply_video_input_accepts([source], DECLARED)
    assert declared_result[0] is source
    # ... and a cap left by an earlier binding is lifted again.
    (lifted,) = apply_video_input_accepts([capped], DECLARED)
    lifted_path = str(tmp_path / "lifted.mp4")
    lifted.save_to(lifted_path)
    assert probe(lifted_path) == ten_bit

    # Declaring depth 8 behaves like not declaring anything.
    depth8_result = apply_video_input_accepts([source], {"accepts": {"depth": 8}})
    assert depth8_result[0] is not source

    # Subclasses, component videos, custom implementations and non-videos pass through.
    passthrough = [
        SubVideo(src10),
        VideoFromComponents(gradient_components),
        CustomVideo(),
        "not a video",
        None,
    ]
    assert apply_video_input_accepts(passthrough, None) == passthrough
def test_accepts_declaration():
    """Video.Input validates and serializes ``accepts``; SaveVideo and VideoSlice declare it."""
    from comfy_api.latest import io
    import comfy_extras.nodes_video as nv
    from comfy_execution.graph import get_input_info

    ten_bit_accepts = {"depth": 10}

    # Serialization: present only when declared.
    declared = io.Video.Input("video", accepts={"depth": 10}).as_dict()
    assert declared["accepts"] == ten_bit_accepts
    undeclared = io.Video.Input("video").as_dict()
    assert "accepts" not in undeclared

    # Validation: unknown keys and unsupported depths are rejected.
    with pytest.raises(ValueError, match="Unsupported keys"):
        io.Video.Input("video", accepts={"codec": "h264"})
    with pytest.raises(ValueError, match="must be 8 or 10"):
        io.Video.Input("video", accepts={"depth": 12})

    # Both nodes advertise 10-bit support on their "video" input.
    for node in (nv.SaveVideo, nv.VideoSlice):
        _, _, input_info = get_input_info(node, "video", node.INPUT_TYPES())
        assert input_info.get("accepts") == ten_bit_accepts, node
def test_save_video_node_bit_depth_handling(tmp_path, monkeypatch, caplog):
    """SaveVideo passes bit_depth on when it can and degrades gracefully when it cannot.

    A source whose save_to accepts bit_depth produces a genuinely 10-bit file.
    A legacy save_to that predates the parameter triggers a warning and the
    file is still saved, instead of a TypeError being raised.
    """
    import comfy_extras.nodes_video as nv
    from comfy_api.latest._io import HiddenHolder
    from comfy_api.latest._input import VideoInput as VideoInputABC

    # Redirect node output into the per-test temp directory.
    output_dir = str(tmp_path)
    monkeypatch.setattr(nv.folder_paths, "get_output_directory", lambda: output_dir)
    monkeypatch.setattr(nv.SaveVideo, "hidden", HiddenHolder.from_dict(None))

    class LegacyVideo(VideoInputABC):
        def get_dimensions(self):
            return 16, 16

        def get_components(self):
            raise NotImplementedError

        def save_to(self, path, format=None, codec=None, metadata=None):  # no bit_depth
            with open(path, "wb") as handle:
                handle.write(b"data")

        def as_trimmed(self, start_time=None, duration=None, strict_duration=False):
            return self

    # Legacy source: an explicit 10-bit request must not crash; it warns and still saves.
    with caplog.at_level(logging.WARNING):
        nv.SaveVideo.execute(LegacyVideo(), "legacy", "auto", "auto", bit_depth="10-bit")
    assert "does not support bit_depth" in caplog.text
    legacy_outputs = list(tmp_path.glob("legacy*"))
    assert legacy_outputs

    # Supporting source: bit_depth reaches save_to, so the file really is 10-bit.
    columns = torch.linspace(0.25, 0.30, 64).view(1, 1, 64, 1)
    ramp = columns.expand(3, 64, 64, 3).contiguous()
    components = VideoComponents(images=ramp, frame_rate=Fraction(30))
    nv.SaveVideo.execute(
        VideoFromComponents(components),
        "supported",
        "auto",
        "auto",
        bit_depth="10-bit",
    )
    supported_outputs = list(tmp_path.glob("supported*.mp4"))
    assert len(supported_outputs) == 1
    assert probe(str(supported_outputs[0])) == ("h264", "yuv420p10le", 10)