diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 294ad425e..74a9bc167 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D +from ._util import VideoCodec, VideoContainer, VideoBitDepth, VideoComponents, MESH, VOXEL, SPLAT, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -140,6 +140,7 @@ class InputImpl: class Types: VideoCodec = VideoCodec VideoContainer = VideoContainer + VideoBitDepth = VideoBitDepth VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 8fff52c16..ba51c8b6f 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -4,7 +4,7 @@ from fractions import Fraction from typing import Optional, Union, IO import io import av -from .._util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents class VideoInput(ABC): """ @@ -27,7 +27,8 @@ class VideoInput(ABC): path: Union[str, IO[bytes]], format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, + bit_depth: VideoBitDepth = VideoBitDepth.AUTO, ): """ Abstract method to save the video input to a file. diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 4a12ff9c1..5ed2e270f 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -10,7 +10,7 @@ import json import numpy as np import math import torch -from .._util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents import logging @@ -52,12 +52,19 @@ def get_open_write_kwargs( return open_kwargs +def video_stream_bit_depth(stream) -> int: + """Best-effort bit depth of a video stream's pixel format; defaults to 8.""" + if stream is None or stream.format is None or not stream.format.components: + return 8 + return max(component.bits for component in stream.format.components) + + class VideoFromFile(VideoInput): """ Class representing video input from a file. """ - def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0): + def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0, bit_depth_cap: int | None = None): """ Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object containing the file contents. @@ -65,6 +72,18 @@ class VideoFromFile(VideoInput): self.__file = file self.__start_time = start_time self.__duration = duration + self.__bit_depth_cap = bit_depth_cap + + def with_bit_depth_cap(self, bit_depth_cap: Optional[int]) -> "VideoFromFile": + """A copy of this video (sharing the same source) whose saved files default to the capped bit depth. + + Returns self when the cap is already in place; None lifts the cap. + """ + if bit_depth_cap == self.__bit_depth_cap: + return self + return VideoFromFile( + self.__file, start_time=self.__start_time, duration=self.__duration, bit_depth_cap=bit_depth_cap + ) def get_stream_source(self) -> str | io.BytesIO: """ @@ -377,25 +396,35 @@ class VideoFromFile(VideoInput): format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, + bit_depth: VideoBitDepth = VideoBitDepth.AUTO, ): + bit_depth = VideoBitDepth(bit_depth) + if bit_depth == VideoBitDepth.AUTO and self.__bit_depth_cap is not None and self.__bit_depth_cap < 10: + bit_depth = VideoBitDepth.BIT_8 if isinstance(self.__file, io.BytesIO): self.__file.seek(0) # Reset the BytesIO object to the beginning with av.open(self.__file, mode='r') as container: container_format = container.format.name - video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None + video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None + video_encoding = video_stream.codec.name if video_stream is not None else None + source_bit_depth = video_stream_bit_depth(video_stream) reuse_streams = True if format != VideoContainer.AUTO and format not in container_format.split(","): reuse_streams = False if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: reuse_streams = False + if bit_depth != VideoBitDepth.AUTO and video_encoding is not None and bit_depth.bits() != source_bit_depth: + reuse_streams = False if self.__start_time or self.__duration: reuse_streams = False if not reuse_streams: + if bit_depth == VideoBitDepth.AUTO: + bit_depth = VideoBitDepth.BIT_10 if source_bit_depth >= 10 else VideoBitDepth.BIT_8 components = self.get_components_internal(container) video = VideoFromComponents(components) return video.save_to( - path, format=format, codec=codec, metadata=metadata + path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth ) streams = container.streams @@ -440,6 +469,7 @@ class VideoFromFile(VideoInput): self.get_stream_source(), start_time=start_time + self.__start_time, duration=duration, + bit_depth_cap=self.__bit_depth_cap, ) if trimmed.get_duration() < duration and strict_duration: return None @@ -467,12 +497,15 @@ class VideoFromComponents(VideoInput): format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, + bit_depth: VideoBitDepth = VideoBitDepth.AUTO, ): """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") + # AUTO is 8-bit: tensor components have no source bit depth to preserve. + is_10bit = VideoBitDepth(bit_depth) == VideoBitDepth.BIT_10 extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value @@ -488,10 +521,11 @@ class VideoFromComponents(VideoInput): frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) # Create a video stream + pix_fmt = 'yuv420p10le' if is_10bit else 'yuv420p' video_stream = output.add_stream('h264', rate=frame_rate) video_stream.width = self.__components.images.shape[2] video_stream.height = self.__components.images.shape[1] - video_stream.pix_fmt = 'yuv420p' + video_stream.pix_fmt = pix_fmt # Create an audio stream audio_sample_rate = 1 @@ -505,9 +539,14 @@ class VideoFromComponents(VideoInput): # Encode video for i, frame in enumerate(self.__components.images): - img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) - frame = av.VideoFrame.from_ndarray(img, format='rgb24') - frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 + if is_10bit: + # 16-bit RGB keeps float precision through the conversion to 10-bit YUV. + img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb48le') + else: + img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame = frame.reformat(format=pix_fmt) packet = video_stream.encode(frame) output.mux(packet) @@ -534,3 +573,19 @@ class VideoFromComponents(VideoInput): return None #TODO Consider tracking duration and trimming at time of save? return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration) + + +def apply_video_input_accepts(values: list, input_info: dict | None) -> list: + """Apply a VIDEO input's `accepts` declaration to its bound values. + + Inputs declaring `accepts={"depth": 10}` receive uncapped videos. + For the rest, file-backed videos are replaced with copies that save as 8-bit by default, + so existing nodes keep producing 8-bit files. + VideoFromFile subclasses and other VideoInput implementations own their depth behavior and pass through unchanged. + """ + accepts = (input_info or {}).get("accepts") or {} + cap = None if accepts.get("depth", 8) >= 10 else 8 + return [ + value.with_bit_depth_cap(cap) if type(value) is VideoFromFile else value + for value in values + ] diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 012fae3ac..04a631626 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -662,6 +662,26 @@ class Video(ComfyTypeIO): if TYPE_CHECKING: Type = VideoInput + class Input(Input): + """Video input socket. + + `accepts` declares which video properties the node handles itself; only "depth" (8 or 10) is supported for now, + e.g. `accepts={"depth": 10}`. Inputs without it receive videos whose saved files are capped to 8-bit. + """ + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + extra_dict=None, raw_link: bool=None, advanced: bool=None, accepts: dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) + if accepts is not None: + unknown_keys = set(accepts) - {"depth"} + if unknown_keys: + raise ValueError(f"Unsupported keys in Video.Input accepts: {sorted(unknown_keys)}") + if "depth" in accepts and accepts["depth"] not in (8, 10): + raise ValueError("Video.Input accepts['depth'] must be 8 or 10") + self.accepts = accepts + + def as_dict(self): + return super().as_dict() | prune_dict({"accepts": self.accepts}) + @comfytype(io_type="SVG") class SVG(ComfyTypeIO): Type = _SVG diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index b27f5a97e..702e0606a 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,4 +1,4 @@ -from .video_types import VideoContainer, VideoCodec, VideoComponents +from .video_types import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents from .geometry_types import VOXEL, MESH, SPLAT, File3D from .image_types import SVG @@ -6,6 +6,7 @@ __all__ = [ # Utility Types "VideoContainer", "VideoCodec", + "VideoBitDepth", "VideoComponents", "VOXEL", "MESH", diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index 6c9d6a526..4d8705c37 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -15,6 +15,23 @@ class VideoCodec(str, Enum): """ return [member.value for member in cls] + +class VideoBitDepth(str, Enum): + AUTO = "auto" + BIT_8 = "8-bit" + BIT_10 = "10-bit" + + @classmethod + def as_input(cls) -> list[str]: + """Returns a list of bit depth names that can be used as node input.""" + return [member.value for member in cls] + + def bits(self) -> Optional[int]: + """Returns the numeric bit depth, or None for AUTO.""" + if self == VideoBitDepth.AUTO: + return None + return int(self.value.split("-")[0]) + class VideoContainer(str, Enum): AUTO = "auto" MP4 = "mp4" diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 6f6c416a6..4b1402cb5 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -3,6 +3,8 @@ import av import torch import folder_paths import json +import inspect +import logging from typing import Optional from typing_extensions import override from fractions import Fraction @@ -71,6 +73,15 @@ class SaveWEBM(io.ComfyNode): return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + +def _save_to_supports_bit_depth(video) -> bool: + try: + params = inspect.signature(video.save_to).parameters + except (TypeError, ValueError): + return True # not introspectable; assume the current contract + return "bit_depth" in params or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()) + + class SaveVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -82,17 +93,26 @@ class SaveVideo(io.ComfyNode): essentials_category="Basics", description="Saves the input images to your ComfyUI output directory.", inputs=[ - io.Video.Input("video", tooltip="The video to save."), + io.Video.Input("video", tooltip="The video to save.", accepts={"depth": 10}), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + io.Combo.Input( + "bit_depth", + options=Types.VideoBitDepth.as_input(), + default="auto", + tooltip="Bit depth used when the video has to be re-encoded." + " 'auto' keeps the bit depth of the source video (videos created from images are saved as 8-bit)." + " 10-bit keeps smoother gradients with less banding, but some players may not support it.", + optional=True, + ), ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput: + def execute(cls, video: Input.Video, filename_prefix, format: str, codec, bit_depth: str = "auto") -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -110,11 +130,22 @@ class SaveVideo(io.ComfyNode): if len(metadata) > 0: saved_metadata = metadata file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" + bit_depth = Types.VideoBitDepth(bit_depth) + save_kwargs = {} + if bit_depth != Types.VideoBitDepth.AUTO: + if _save_to_supports_bit_depth(video): + save_kwargs["bit_depth"] = bit_depth + else: + logging.warning( + "%s.save_to() does not support bit_depth; saving at the source's default depth.", + type(video).__name__, + ) video.save_to( os.path.join(full_output_folder, file), format=Types.VideoContainer(format), codec=codec, - metadata=saved_metadata + metadata=saved_metadata, + **save_kwargs, ) return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) @@ -226,7 +257,7 @@ class VideoSlice(io.ComfyNode): category="video", essentials_category="Video Tools", inputs=[ - io.Video.Input("video"), + io.Video.Input("video", accepts={"depth": 10}), io.Float.Input( "start_time", default=0.0, diff --git a/execution.py b/execution.py index 9e16e451d..8596b18e1 100644 --- a/execution.py +++ b/execution.py @@ -43,6 +43,7 @@ from comfy_execution.utils import CurrentNodeContext from comfy_execution.asset_enrichment import enrich_output_with_assets from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io +from comfy_api.latest._input_impl.video_types import apply_video_input_accepts from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger @@ -164,7 +165,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= missing_keys = {} for x in inputs: input_data = inputs[x] - _, input_category, input_info = get_input_info(class_def, x, valid_inputs) + input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -182,6 +183,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= mark_missing() continue obj = cached.outputs[output_index] + if input_type == io.Video.io_type: + obj = apply_video_input_accepts(obj, input_info) input_data_all[x] = obj elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): input_data_all[x] = [input_data] diff --git a/tests-unit/comfy_api_test/video_bit_depth_test.py b/tests-unit/comfy_api_test/video_bit_depth_test.py new file mode 100644 index 000000000..bac0087e5 --- /dev/null +++ b/tests-unit/comfy_api_test/video_bit_depth_test.py @@ -0,0 +1,238 @@ +import pytest +import torch +import av +import logging +import numpy as np +from fractions import Fraction +from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents +from comfy_api.latest._input_impl.video_types import apply_video_input_accepts +from comfy_api.util.video_types import VideoComponents +from comfy_api.latest._util.video_types import VideoBitDepth + +DECLARED = {"accepts": {"depth": 10}} + + +@pytest.fixture(scope="module") +def gradient_components(): + """Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth""" + width, height, frames = 64, 64, 3 + ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3) + return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30)) + + +@pytest.fixture(scope="module") +def src8(gradient_components, tmp_path_factory): + """8-bit h264 mp4 source file""" + path = str(tmp_path_factory.mktemp("video") / "src8.mp4") + VideoFromComponents(gradient_components).save_to(path) + return path + + +@pytest.fixture(scope="module") +def src10(gradient_components, tmp_path_factory): + """10-bit h264 mp4 source file""" + path = str(tmp_path_factory.mktemp("video") / "src10.mp4") + VideoFromComponents(gradient_components).save_to(path, bit_depth=VideoBitDepth.BIT_10) + return path + + +def probe(path): + """Return (codec, pix_fmt, bit_depth) of the first video stream""" + with av.open(path) as container: + stream = container.streams.video[0] + return ( + stream.codec.name, + stream.format.name, + max(component.bits for component in stream.format.components), + ) + + +def decoded_levels(path): + """Unique tonal levels in the first decoded frame (banding measure)""" + with av.open(path) as container: + frame = next(container.decode(container.streams.video[0])) + return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0])) + + +def video_packet_bytes(path): + """Raw video packet payloads; identical to the source's only for a true remux""" + with av.open(path) as container: + return [bytes(packet) for packet in container.demux(container.streams.video[0]) if packet.size] + + +def test_components_save_bit_depths(src8, src10): + """Default save stays 8-bit h264; 10-bit keeps h264 and clearly reduces banding""" + assert probe(src8) == ("h264", "yuv420p", 8) + assert probe(src10) == ("h264", "yuv420p10le", 10) + assert decoded_levels(src10) > 2 * decoded_levels(src8) + + +def test_components_unsupported_codec_raises(gradient_components, tmp_path): + with pytest.raises(ValueError, match="H264"): + VideoFromComponents(gradient_components).save_to(str(tmp_path / "x.mp4"), codec="vp9") + + +def test_bit_depth_enum(): + assert VideoBitDepth.as_input() == ["auto", "8-bit", "10-bit"] + assert [d.bits() for d in VideoBitDepth] == [None, 8, 10] + + +def test_10bit_source_remuxes_untouched(src10, tmp_path): + """auto and a cap of 10 both keep a 10-bit stream untouched""" + for name, video in [("auto", VideoFromFile(src10)), ("cap10", VideoFromFile(src10).with_bit_depth_cap(10))]: + path = str(tmp_path / f"{name}.mp4") + video.save_to(path) + assert probe(path) == ("h264", "yuv420p10le", 10) + assert video_packet_bytes(path) == video_packet_bytes(src10) + + +def test_8bit_source_remuxes_on_8bit_request(src8, tmp_path): + """Neither explicit 8-bit nor a cap of 8 re-encodes an already 8-bit source""" + for name, save in [ + ("explicit", lambda p: VideoFromFile(src8).save_to(p, bit_depth="8-bit")), + ("capped", lambda p: VideoFromFile(src8).with_bit_depth_cap(8).save_to(p)), + ]: + path = str(tmp_path / f"{name}.mp4") + save(path) + assert video_packet_bytes(path) == video_packet_bytes(src8) + + +def test_trim_keeps_source_depth(src10, tmp_path): + """A re-encode forced by trimming preserves the source's 10-bit depth""" + path = str(tmp_path / "trim.mp4") + VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path) + assert probe(path) == ("h264", "yuv420p10le", 10) + + +def test_explicit_depth_mismatch_forces_reencode(src8, src10, tmp_path): + """An explicit depth that differs from the source's re-encodes instead of remuxing""" + down = str(tmp_path / "down8.mp4") + VideoFromFile(src10).save_to(down, bit_depth=VideoBitDepth.BIT_8) + assert probe(down) == ("h264", "yuv420p", 8) + + up = str(tmp_path / "up10.mp4") + VideoFromFile(src8).save_to(up, bit_depth=VideoBitDepth.BIT_10) + assert probe(up) == ("h264", "yuv420p10le", 10) + + +def test_bit_depth_cap(src10, tmp_path): + """A cap of 8 makes saves default to 8-bit (also through as_trimmed), but an + explicit request wins, and tensor access keeps full precision""" + capped = VideoFromFile(src10).with_bit_depth_cap(8) + + path = str(tmp_path / "capped.mp4") + capped.save_to(path) + assert probe(path) == ("h264", "yuv420p", 8) + + trimmed = str(tmp_path / "trimmed.mp4") + capped.as_trimmed(0, 1 / 30, strict_duration=False).save_to(trimmed) + assert probe(trimmed) == ("h264", "yuv420p", 8) + + explicit = str(tmp_path / "explicit10.mp4") + capped.save_to(explicit, bit_depth=VideoBitDepth.BIT_10) + assert probe(explicit) == ("h264", "yuv420p10le", 10) + + images = capped.get_components().images + assert images.dtype == torch.float32 + assert len(torch.unique(images[0, :, :, 0])) > 30 # ~13 levels if quantized to 8-bit + + +def test_accepts_binding_policy(gradient_components, src10, tmp_path): + """Undeclared inputs get an 8-bit-capped copy of file videos; declared inputs + get uncapped videos; everything else passes through untouched""" + video = VideoFromFile(src10) + + # undeclared input: capped copy that saves 8-bit + [capped] = apply_video_input_accepts([video], {"tooltip": "x"}) + assert type(capped) is VideoFromFile and capped is not video + bound = str(tmp_path / "bound.mp4") + capped.save_to(bound) + assert probe(bound) == ("h264", "yuv420p", 8) + + # declared input: original passes through; a cap from an earlier binding is lifted + assert apply_video_input_accepts([video], DECLARED)[0] is video + [lifted] = apply_video_input_accepts([capped], DECLARED) + lifted_path = str(tmp_path / "lifted.mp4") + lifted.save_to(lifted_path) + assert probe(lifted_path) == ("h264", "yuv420p10le", 10) + + # declaring depth 8 is the same as not declaring + assert apply_video_input_accepts([video], {"accepts": {"depth": 8}})[0] is not video + + # subclasses, component videos, custom implementations, and non-videos pass through + from comfy_api.latest._input import VideoInput as VideoInputABC + + class SubVideo(VideoFromFile): + pass + + class CustomVideo(VideoInputABC): + def get_components(self): + raise NotImplementedError + + def save_to(self, path, format=None, codec=None, metadata=None): + raise NotImplementedError + + def as_trimmed(self, start_time=None, duration=None, strict_duration=False): + return self + + passthrough = [SubVideo(src10), VideoFromComponents(gradient_components), CustomVideo(), "not a video", None] + assert apply_video_input_accepts(passthrough, None) == passthrough + + +def test_accepts_declaration(): + """Video.Input validates and serializes accepts; SaveVideo and VideoSlice declare it""" + from comfy_api.latest import io + import comfy_extras.nodes_video as nv + from comfy_execution.graph import get_input_info + + assert io.Video.Input("video", accepts={"depth": 10}).as_dict()["accepts"] == {"depth": 10} + assert "accepts" not in io.Video.Input("video").as_dict() + with pytest.raises(ValueError, match="Unsupported keys"): + io.Video.Input("video", accepts={"codec": "h264"}) + with pytest.raises(ValueError, match="must be 8 or 10"): + io.Video.Input("video", accepts={"depth": 12}) + + for node in (nv.SaveVideo, nv.VideoSlice): + _, _, info = get_input_info(node, "video", node.INPUT_TYPES()) + assert info.get("accepts") == {"depth": 10}, node + + +def test_save_video_node_bit_depth_handling(tmp_path, monkeypatch, caplog): + """SaveVideo forwards bit_depth to a source that accepts it (the file is really 10-bit), + and for a legacy save_to that predates the parameter it warns and saves anyway instead of raising TypeError""" + import comfy_extras.nodes_video as nv + from comfy_api.latest._io import HiddenHolder + from comfy_api.latest._input import VideoInput as VideoInputABC + + monkeypatch.setattr(nv.folder_paths, "get_output_directory", lambda: str(tmp_path)) + monkeypatch.setattr(nv.SaveVideo, "hidden", HiddenHolder.from_dict(None)) + + class LegacyVideo(VideoInputABC): + def get_dimensions(self): + return 16, 16 + + def get_components(self): + raise NotImplementedError + + def save_to(self, path, format=None, codec=None, metadata=None): # no bit_depth + with open(path, "wb") as f: + f.write(b"data") + + def as_trimmed(self, start_time=None, duration=None, strict_duration=False): + return self + + # legacy source: an explicit 10-bit request must not crash; it warns and still saves + with caplog.at_level(logging.WARNING): + nv.SaveVideo.execute(LegacyVideo(), "legacy", "auto", "auto", bit_depth="10-bit") + assert "does not support bit_depth" in caplog.text + assert list(tmp_path.glob("legacy*")) + + # supporting source: bit_depth reaches save_to, so the file really is 10-bit + ramp = torch.linspace(0.25, 0.30, 64).view(1, 1, 64, 1).expand(3, 64, 64, 3).contiguous() + nv.SaveVideo.execute( + VideoFromComponents(VideoComponents(images=ramp, frame_rate=Fraction(30))), + "supported", "auto", "auto", bit_depth="10-bit", + ) + outs = list(tmp_path.glob("supported*.mp4")) + assert len(outs) == 1 + assert probe(str(outs[0])) == ("h264", "yuv420p10le", 10)