mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Add 10-bit video support to Save Video
Save Video gets a bit_depth widget (auto/8-bit/10-bit). 'auto' preserves
the source file's bit depth when re-encoding; 10-bit encodes h264
yuv420p10le from 16-bit RGB frames so float-precision sources keep their
gradients instead of being quantized to 8-bit.
Video inputs can declare 10-bit support via Video.Input(accepts={"depth": 10}).
At input binding, file-backed videos bound to inputs without the declaration are
replaced with copies whose saved files default to 8-bit, so existing nodes keep producing 8-bit files no matter the
source depth. SaveVideo and VideoSlice declare support, so trimming a
10-bit video and saving it keeps 10-bit.
Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
parent
7277d99d3a
commit
87790af8a7
@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
|||||||
from comfy_api.internal.async_to_sync import create_sync_class
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
from ._util import VideoCodec, VideoContainer, VideoBitDepth, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
||||||
from . import _io_public as io
|
from . import _io_public as io
|
||||||
from . import _ui_public as ui
|
from . import _ui_public as ui
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
@ -140,6 +140,7 @@ class InputImpl:
|
|||||||
class Types:
|
class Types:
|
||||||
VideoCodec = VideoCodec
|
VideoCodec = VideoCodec
|
||||||
VideoContainer = VideoContainer
|
VideoContainer = VideoContainer
|
||||||
|
VideoBitDepth = VideoBitDepth
|
||||||
VideoComponents = VideoComponents
|
VideoComponents = VideoComponents
|
||||||
MESH = MESH
|
MESH = MESH
|
||||||
VOXEL = VOXEL
|
VOXEL = VOXEL
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from fractions import Fraction
|
|||||||
from typing import Optional, Union, IO
|
from typing import Optional, Union, IO
|
||||||
import io
|
import io
|
||||||
import av
|
import av
|
||||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
from .._util import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents
|
||||||
|
|
||||||
class VideoInput(ABC):
|
class VideoInput(ABC):
|
||||||
"""
|
"""
|
||||||
@ -27,7 +27,8 @@ class VideoInput(ABC):
|
|||||||
path: Union[str, IO[bytes]],
|
path: Union[str, IO[bytes]],
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: VideoBitDepth = VideoBitDepth.AUTO,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Abstract method to save the video input to a file.
|
Abstract method to save the video input to a file.
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import json
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
from .._util import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -52,12 +52,19 @@ def get_open_write_kwargs(
|
|||||||
return open_kwargs
|
return open_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def video_stream_bit_depth(stream) -> int:
    """Return the largest per-component bit count of a stream's pixel format.

    Best effort: 8 is returned when ``stream`` is None, when it has no pixel
    format, or when that format lists no components.
    """
    if stream is None:
        return 8
    pixel_format = stream.format
    if pixel_format is None:
        return 8
    if not pixel_format.components:
        return 8
    return max(plane.bits for plane in pixel_format.components)
|
||||||
|
|
||||||
|
|
||||||
class VideoFromFile(VideoInput):
|
class VideoFromFile(VideoInput):
|
||||||
"""
|
"""
|
||||||
Class representing video input from a file.
|
Class representing video input from a file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
|
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0, bit_depth_cap: int | None = None):
|
||||||
"""
|
"""
|
||||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||||
containing the file contents.
|
containing the file contents.
|
||||||
@ -65,6 +72,18 @@ class VideoFromFile(VideoInput):
|
|||||||
self.__file = file
|
self.__file = file
|
||||||
self.__start_time = start_time
|
self.__start_time = start_time
|
||||||
self.__duration = duration
|
self.__duration = duration
|
||||||
|
self.__bit_depth_cap = bit_depth_cap
|
||||||
|
|
||||||
|
def with_bit_depth_cap(self, bit_depth_cap: Optional[int]) -> "VideoFromFile":
    """Return a video over the same source that carries ``bit_depth_cap``.

    The cap sets the bit depth that saved files default to. When this video
    already has the requested cap, ``self`` is returned and no copy is made.
    Passing None produces an uncapped video. The copy keeps this video's
    start time and duration.
    """
    cap_unchanged = bit_depth_cap == self.__bit_depth_cap
    if cap_unchanged:
        return self
    copy_options = {
        "start_time": self.__start_time,
        "duration": self.__duration,
        "bit_depth_cap": bit_depth_cap,
    }
    return VideoFromFile(self.__file, **copy_options)
|
||||||
|
|
||||||
def get_stream_source(self) -> str | io.BytesIO:
|
def get_stream_source(self) -> str | io.BytesIO:
|
||||||
"""
|
"""
|
||||||
@ -377,25 +396,35 @@ class VideoFromFile(VideoInput):
|
|||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: VideoBitDepth = VideoBitDepth.AUTO,
|
||||||
):
|
):
|
||||||
|
bit_depth = VideoBitDepth(bit_depth)
|
||||||
|
if bit_depth == VideoBitDepth.AUTO and self.__bit_depth_cap is not None and self.__bit_depth_cap < 10:
|
||||||
|
bit_depth = VideoBitDepth.BIT_8
|
||||||
if isinstance(self.__file, io.BytesIO):
|
if isinstance(self.__file, io.BytesIO):
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
with av.open(self.__file, mode='r') as container:
|
with av.open(self.__file, mode='r') as container:
|
||||||
container_format = container.format.name
|
container_format = container.format.name
|
||||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||||
|
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||||
|
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||||
reuse_streams = True
|
reuse_streams = True
|
||||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
|
if bit_depth != VideoBitDepth.AUTO and video_encoding is not None and bit_depth.bits() != source_bit_depth:
|
||||||
|
reuse_streams = False
|
||||||
if self.__start_time or self.__duration:
|
if self.__start_time or self.__duration:
|
||||||
reuse_streams = False
|
reuse_streams = False
|
||||||
|
|
||||||
if not reuse_streams:
|
if not reuse_streams:
|
||||||
|
if bit_depth == VideoBitDepth.AUTO:
|
||||||
|
bit_depth = VideoBitDepth.BIT_10 if source_bit_depth >= 10 else VideoBitDepth.BIT_8
|
||||||
components = self.get_components_internal(container)
|
components = self.get_components_internal(container)
|
||||||
video = VideoFromComponents(components)
|
video = VideoFromComponents(components)
|
||||||
return video.save_to(
|
return video.save_to(
|
||||||
path, format=format, codec=codec, metadata=metadata
|
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth
|
||||||
)
|
)
|
||||||
|
|
||||||
streams = container.streams
|
streams = container.streams
|
||||||
@ -440,6 +469,7 @@ class VideoFromFile(VideoInput):
|
|||||||
self.get_stream_source(),
|
self.get_stream_source(),
|
||||||
start_time=start_time + self.__start_time,
|
start_time=start_time + self.__start_time,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
|
bit_depth_cap=self.__bit_depth_cap,
|
||||||
)
|
)
|
||||||
if trimmed.get_duration() < duration and strict_duration:
|
if trimmed.get_duration() < duration and strict_duration:
|
||||||
return None
|
return None
|
||||||
@ -467,12 +497,15 @@ class VideoFromComponents(VideoInput):
|
|||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
bit_depth: VideoBitDepth = VideoBitDepth.AUTO,
|
||||||
):
|
):
|
||||||
"""Save the video to a file path or BytesIO buffer."""
|
"""Save the video to a file path or BytesIO buffer."""
|
||||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||||
raise ValueError("Only MP4 format is supported for now")
|
raise ValueError("Only MP4 format is supported for now")
|
||||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||||
raise ValueError("Only H264 codec is supported for now")
|
raise ValueError("Only H264 codec is supported for now")
|
||||||
|
# AUTO is 8-bit: tensor components have no source bit depth to preserve.
|
||||||
|
is_10bit = VideoBitDepth(bit_depth) == VideoBitDepth.BIT_10
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||||
extra_kwargs["format"] = format.value
|
extra_kwargs["format"] = format.value
|
||||||
@ -488,10 +521,11 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||||
# Create a video stream
|
# Create a video stream
|
||||||
|
pix_fmt = 'yuv420p10le' if is_10bit else 'yuv420p'
|
||||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||||
video_stream.width = self.__components.images.shape[2]
|
video_stream.width = self.__components.images.shape[2]
|
||||||
video_stream.height = self.__components.images.shape[1]
|
video_stream.height = self.__components.images.shape[1]
|
||||||
video_stream.pix_fmt = 'yuv420p'
|
video_stream.pix_fmt = pix_fmt
|
||||||
|
|
||||||
# Create an audio stream
|
# Create an audio stream
|
||||||
audio_sample_rate = 1
|
audio_sample_rate = 1
|
||||||
@ -505,9 +539,14 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
# Encode video
|
# Encode video
|
||||||
for i, frame in enumerate(self.__components.images):
|
for i, frame in enumerate(self.__components.images):
|
||||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
if is_10bit:
|
||||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format='rgb48le')
|
||||||
|
else:
|
||||||
|
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||||
|
frame = frame.reformat(format=pix_fmt)
|
||||||
packet = video_stream.encode(frame)
|
packet = video_stream.encode(frame)
|
||||||
output.mux(packet)
|
output.mux(packet)
|
||||||
|
|
||||||
@ -534,3 +573,19 @@ class VideoFromComponents(VideoInput):
|
|||||||
return None
|
return None
|
||||||
#TODO Consider tracking duration and trimming at time of save?
|
#TODO Consider tracking duration and trimming at time of save?
|
||||||
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_video_input_accepts(values: list, input_info: dict | None) -> list:
|
||||||
|
"""Apply a VIDEO input's `accepts` declaration to its bound values.
|
||||||
|
|
||||||
|
Inputs declaring `accepts={"depth": 10}` receive uncapped videos.
|
||||||
|
For the rest, file-backed videos are replaced with copies that save as 8-bit by default,
|
||||||
|
so existing nodes keep producing 8-bit files.
|
||||||
|
VideoFromFile subclasses and other VideoInput implementations own their depth behavior and pass through unchanged.
|
||||||
|
"""
|
||||||
|
accepts = (input_info or {}).get("accepts") or {}
|
||||||
|
cap = None if accepts.get("depth", 8) >= 10 else 8
|
||||||
|
return [
|
||||||
|
value.with_bit_depth_cap(cap) if type(value) is VideoFromFile else value
|
||||||
|
for value in values
|
||||||
|
]
|
||||||
|
|||||||
@ -662,6 +662,26 @@ class Video(ComfyTypeIO):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Type = VideoInput
|
Type = VideoInput
|
||||||
|
|
||||||
|
class Input(Input):
    """Video input socket.

    `accepts` declares which video properties the node handles itself; only "depth" (8 or 10) is supported for now,
    e.g. `accepts={"depth": 10}`. Inputs without it receive videos whose saved files are capped to 8-bit.
    """
    def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
                 extra_dict=None, raw_link: bool=None, advanced: bool=None, accepts: dict=None):
        """Create the socket and validate the optional `accepts` declaration.

        All arguments except `accepts` are forwarded unchanged to the base input class.

        Raises:
            ValueError: If `accepts` has a key other than "depth", or if its "depth" is not 8 or 10.
        """
        super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced)
        if accepts is not None:
            unknown_keys = set(accepts) - {"depth"}
            if unknown_keys:
                raise ValueError(f"Unsupported keys in Video.Input accepts: {sorted(unknown_keys)}")
            if "depth" in accepts and accepts["depth"] not in (8, 10):
                raise ValueError("Video.Input accepts['depth'] must be 8 or 10")
            # Store a copy so that later mutation of the caller's dict cannot bypass the validation above.
            accepts = dict(accepts)
        self.accepts = accepts

    def as_dict(self):
        """Return the base dict merged with the pruned `accepts` entry."""
        return super().as_dict() | prune_dict({"accepts": self.accepts})
|
||||||
|
|
||||||
@comfytype(io_type="SVG")
|
@comfytype(io_type="SVG")
|
||||||
class SVG(ComfyTypeIO):
|
class SVG(ComfyTypeIO):
|
||||||
Type = _SVG
|
Type = _SVG
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
from .video_types import VideoContainer, VideoCodec, VideoBitDepth, VideoComponents
|
||||||
from .geometry_types import VOXEL, MESH, SPLAT, File3D
|
from .geometry_types import VOXEL, MESH, SPLAT, File3D
|
||||||
from .image_types import SVG
|
from .image_types import SVG
|
||||||
|
|
||||||
@ -6,6 +6,7 @@ __all__ = [
|
|||||||
# Utility Types
|
# Utility Types
|
||||||
"VideoContainer",
|
"VideoContainer",
|
||||||
"VideoCodec",
|
"VideoCodec",
|
||||||
|
"VideoBitDepth",
|
||||||
"VideoComponents",
|
"VideoComponents",
|
||||||
"VOXEL",
|
"VOXEL",
|
||||||
"MESH",
|
"MESH",
|
||||||
|
|||||||
@ -15,6 +15,23 @@ class VideoCodec(str, Enum):
|
|||||||
"""
|
"""
|
||||||
return [member.value for member in cls]
|
return [member.value for member in cls]
|
||||||
|
|
||||||
|
|
||||||
|
class VideoBitDepth(str, Enum):
    """Bit depth options for saving video: "auto", "8-bit" or "10-bit"."""

    AUTO = "auto"
    BIT_8 = "8-bit"
    BIT_10 = "10-bit"

    @classmethod
    def as_input(cls) -> list[str]:
        """List the option strings, in declaration order, for use as node input choices."""
        return [option.value for option in cls]

    def bits(self) -> Optional[int]:
        """Return the depth as a number; AUTO has none and yields None."""
        if self is VideoBitDepth.AUTO:
            return None
        # Values look like "<N>-bit"; the number is the text before the first dash.
        digits, _, _ = self.value.partition("-")
        return int(digits)
|
||||||
|
|
||||||
class VideoContainer(str, Enum):
|
class VideoContainer(str, Enum):
|
||||||
AUTO = "auto"
|
AUTO = "auto"
|
||||||
MP4 = "mp4"
|
MP4 = "mp4"
|
||||||
|
|||||||
@ -3,6 +3,8 @@ import av
|
|||||||
import torch
|
import torch
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import json
|
import json
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
@ -71,6 +73,15 @@ class SaveWEBM(io.ComfyNode):
|
|||||||
|
|
||||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||||
|
|
||||||
|
|
||||||
|
def _save_to_supports_bit_depth(video) -> bool:
|
||||||
|
try:
|
||||||
|
params = inspect.signature(video.save_to).parameters
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return True # not introspectable; assume the current contract
|
||||||
|
return "bit_depth" in params or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
|
||||||
|
|
||||||
|
|
||||||
class SaveVideo(io.ComfyNode):
|
class SaveVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -82,17 +93,26 @@ class SaveVideo(io.ComfyNode):
|
|||||||
essentials_category="Basics",
|
essentials_category="Basics",
|
||||||
description="Saves the input images to your ComfyUI output directory.",
|
description="Saves the input images to your ComfyUI output directory.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video", tooltip="The video to save."),
|
io.Video.Input("video", tooltip="The video to save.", accepts={"depth": 10}),
|
||||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||||
|
io.Combo.Input(
|
||||||
|
"bit_depth",
|
||||||
|
options=Types.VideoBitDepth.as_input(),
|
||||||
|
default="auto",
|
||||||
|
tooltip="Bit depth used when the video has to be re-encoded."
|
||||||
|
" 'auto' keeps the bit depth of the source video (videos created from images are saved as 8-bit)."
|
||||||
|
" 10-bit keeps smoother gradients with less banding, but some players may not support it.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
|
def execute(cls, video: Input.Video, filename_prefix, format: str, codec, bit_depth: str = "auto") -> io.NodeOutput:
|
||||||
width, height = video.get_dimensions()
|
width, height = video.get_dimensions()
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||||
filename_prefix,
|
filename_prefix,
|
||||||
@ -110,11 +130,22 @@ class SaveVideo(io.ComfyNode):
|
|||||||
if len(metadata) > 0:
|
if len(metadata) > 0:
|
||||||
saved_metadata = metadata
|
saved_metadata = metadata
|
||||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||||
|
bit_depth = Types.VideoBitDepth(bit_depth)
|
||||||
|
save_kwargs = {}
|
||||||
|
if bit_depth != Types.VideoBitDepth.AUTO:
|
||||||
|
if _save_to_supports_bit_depth(video):
|
||||||
|
save_kwargs["bit_depth"] = bit_depth
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
"%s.save_to() does not support bit_depth; saving at the source's default depth.",
|
||||||
|
type(video).__name__,
|
||||||
|
)
|
||||||
video.save_to(
|
video.save_to(
|
||||||
os.path.join(full_output_folder, file),
|
os.path.join(full_output_folder, file),
|
||||||
format=Types.VideoContainer(format),
|
format=Types.VideoContainer(format),
|
||||||
codec=codec,
|
codec=codec,
|
||||||
metadata=saved_metadata
|
metadata=saved_metadata,
|
||||||
|
**save_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
|
||||||
@ -226,7 +257,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
category="video",
|
category="video",
|
||||||
essentials_category="Video Tools",
|
essentials_category="Video Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video"),
|
io.Video.Input("video", accepts={"depth": 10}),
|
||||||
io.Float.Input(
|
io.Float.Input(
|
||||||
"start_time",
|
"start_time",
|
||||||
default=0.0,
|
default=0.0,
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from comfy_execution.utils import CurrentNodeContext
|
|||||||
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from comfy_api.latest import io, _io
|
from comfy_api.latest import io, _io
|
||||||
|
from comfy_api.latest._input_impl.video_types import apply_video_input_accepts
|
||||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||||
|
|
||||||
|
|
||||||
@ -164,7 +165,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
def mark_missing():
|
def mark_missing():
|
||||||
missing_keys[x] = True
|
missing_keys[x] = True
|
||||||
input_data_all[x] = (None,)
|
input_data_all[x] = (None,)
|
||||||
@ -182,6 +183,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
mark_missing()
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
obj = cached.outputs[output_index]
|
obj = cached.outputs[output_index]
|
||||||
|
if input_type == io.Video.io_type:
|
||||||
|
obj = apply_video_input_accepts(obj, input_info)
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
|
elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS):
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|||||||
238
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
238
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import av
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from fractions import Fraction
|
||||||
|
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||||
|
from comfy_api.latest._input_impl.video_types import apply_video_input_accepts
|
||||||
|
from comfy_api.util.video_types import VideoComponents
|
||||||
|
from comfy_api.latest._util.video_types import VideoBitDepth
|
||||||
|
|
||||||
|
DECLARED = {"accepts": {"depth": 10}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def gradient_components():
    """Video components with a narrow horizontal ramp (0.25..0.30).

    The ramp is too fine to stay smooth in 8 bits: 3 frames of 64x64 with 3 channels at 30 fps.
    """
    frame_count, rows, cols = 3, 64, 64
    column_values = torch.linspace(0.25, 0.30, cols)
    # Lay the ramp along the width axis, then broadcast it over frames, rows and channels.
    images = column_values.view(1, 1, cols, 1).expand(frame_count, rows, cols, 3).contiguous()
    return VideoComponents(images=images, frame_rate=Fraction(30))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Path of an mp4 written from the gradient with the default save settings (the 8-bit source)"""
    target = str(tmp_path_factory.mktemp("video") / "src8.mp4")
    encoder = VideoFromComponents(gradient_components)
    encoder.save_to(target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Path of an mp4 written from the gradient with an explicit 10-bit request (the 10-bit source)"""
    target = str(tmp_path_factory.mktemp("video") / "src10.mp4")
    encoder = VideoFromComponents(gradient_components)
    encoder.save_to(target, bit_depth=VideoBitDepth.BIT_10)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def probe(path):
    """Open ``path`` and describe its first video stream as (codec name, pixel format name, max component bits)"""
    with av.open(path) as container:
        first_video = container.streams.video[0]
        codec_name = first_video.codec.name
        pixel_format_name = first_video.format.name
        widest = max(component.bits for component in first_video.format.components)
        return codec_name, pixel_format_name, widest
|
||||||
|
|
||||||
|
|
||||||
|
def decoded_levels(path):
    """Count the distinct values in the first decoded frame (banding measure).

    The frame is converted to "gbrpf32le" and only index 0 of the last axis is counted.
    """
    with av.open(path) as container:
        first_frame = next(container.decode(container.streams.video[0]))
        samples = first_frame.to_ndarray(format="gbrpf32le")[..., 0]
        distinct = np.unique(samples)
        return len(distinct)
|
||||||
|
|
||||||
|
|
||||||
|
def video_packet_bytes(path):
    """Collect the payload of every non-empty packet of the first video stream.

    Two files give equal lists only when the video stream was remuxed rather than re-encoded.
    """
    payloads = []
    with av.open(path) as container:
        for packet in container.demux(container.streams.video[0]):
            if packet.size:
                payloads.append(bytes(packet))
    return payloads
|
||||||
|
|
||||||
|
|
||||||
|
def test_components_save_bit_depths(src8, src10):
    """The default save is 8-bit h264; a 10-bit save stays h264 and more than doubles the decoded levels"""
    info_8 = probe(src8)
    assert info_8 == ("h264", "yuv420p", 8)

    info_10 = probe(src10)
    assert info_10 == ("h264", "yuv420p10le", 10)

    levels_10 = decoded_levels(src10)
    levels_8 = decoded_levels(src8)
    assert levels_10 > 2 * levels_8
|
||||||
|
|
||||||
|
|
||||||
|
def test_components_unsupported_codec_raises(gradient_components, tmp_path):
    """Asking for a codec other than h264 is rejected with a ValueError that mentions H264"""
    target = str(tmp_path / "x.mp4")
    with pytest.raises(ValueError, match="H264"):
        VideoFromComponents(gradient_components).save_to(target, codec="vp9")
|
||||||
|
|
||||||
|
|
||||||
|
def test_bit_depth_enum():
    """The enum exposes its option strings in order and maps each member to its numeric depth"""
    options = VideoBitDepth.as_input()
    assert options == ["auto", "8-bit", "10-bit"]

    numeric_depths = [member.bits() for member in VideoBitDepth]
    assert numeric_depths == [None, 8, 10]
|
||||||
|
|
||||||
|
|
||||||
|
def test_10bit_source_remuxes_untouched(src10, tmp_path):
    """Saving a 10-bit file with no cap, or with a cap of 10, leaves the video packets unchanged"""
    candidates = {
        "auto": VideoFromFile(src10),
        "cap10": VideoFromFile(src10).with_bit_depth_cap(10),
    }
    for label, video in candidates.items():
        saved = str(tmp_path / f"{label}.mp4")
        video.save_to(saved)
        assert probe(saved) == ("h264", "yuv420p10le", 10)
        # Identical packet payloads mean the stream was copied, not re-encoded.
        assert video_packet_bytes(saved) == video_packet_bytes(src10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_8bit_source_remuxes_on_8bit_request(src8, tmp_path):
    """An 8-bit source keeps its packets when 8-bit is requested explicitly or comes from a cap of 8"""

    def save_explicit(target):
        return VideoFromFile(src8).save_to(target, bit_depth="8-bit")

    def save_capped(target):
        return VideoFromFile(src8).with_bit_depth_cap(8).save_to(target)

    for label, save in (("explicit", save_explicit), ("capped", save_capped)):
        saved = str(tmp_path / f"{label}.mp4")
        save(saved)
        assert video_packet_bytes(saved) == video_packet_bytes(src8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_keeps_source_depth(src10, tmp_path):
    """Saving a trimmed 10-bit video (which cannot reuse the streams) still produces 10-bit output"""
    saved = str(tmp_path / "trim.mp4")
    trimmed = VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False)
    trimmed.save_to(saved)
    assert probe(saved) == ("h264", "yuv420p10le", 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_depth_mismatch_forces_reencode(src8, src10, tmp_path):
    """Requesting a depth that differs from the source's yields output at the requested depth, in both directions"""
    cases = [
        # (output name, source file, requested depth, expected probe result)
        ("down8.mp4", src10, VideoBitDepth.BIT_8, ("h264", "yuv420p", 8)),
        ("up10.mp4", src8, VideoBitDepth.BIT_10, ("h264", "yuv420p10le", 10)),
    ]
    for filename, source, requested, expected in cases:
        saved = str(tmp_path / filename)
        VideoFromFile(source).save_to(saved, bit_depth=requested)
        assert probe(saved) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_bit_depth_cap(src10, tmp_path):
    """With a cap of 8, saves default to 8-bit, including after as_trimmed.

    An explicit 10-bit request still wins over the cap, and the decoded tensors keep their float precision.
    """
    eight_bit = ("h264", "yuv420p", 8)
    capped = VideoFromFile(src10).with_bit_depth_cap(8)

    # Plain save: the cap decides the depth.
    default_out = str(tmp_path / "capped.mp4")
    capped.save_to(default_out)
    assert probe(default_out) == eight_bit

    # The cap survives trimming.
    trimmed_out = str(tmp_path / "trimmed.mp4")
    capped.as_trimmed(0, 1 / 30, strict_duration=False).save_to(trimmed_out)
    assert probe(trimmed_out) == eight_bit

    # An explicit request overrides the cap.
    explicit_out = str(tmp_path / "explicit10.mp4")
    capped.save_to(explicit_out, bit_depth=VideoBitDepth.BIT_10)
    assert probe(explicit_out) == ("h264", "yuv420p10le", 10)

    # The cap only affects saving, not the tensors handed to nodes.
    decoded = capped.get_components().images
    assert decoded.dtype == torch.float32
    distinct_values = torch.unique(decoded[0, :, :, 0])
    assert len(distinct_values) > 30  # ~13 levels if quantized to 8-bit
|
||||||
|
|
||||||
|
|
||||||
|
def test_accepts_binding_policy(gradient_components, src10, tmp_path):
    """Exercise ``apply_video_input_accepts`` across the input-binding cases.

    * An input without an ``accepts`` declaration receives a distinct
      ``VideoFromFile`` whose default save is 8-bit.
    * An input bound with ``DECLARED`` receives the original object, and a
      previously capped copy saves 10-bit again.
    * Declaring depth 8 behaves like declaring nothing.
    * Subclasses, component videos, custom implementations and non-video
      values come back unchanged.
    """
    source_video = VideoFromFile(src10)

    # undeclared input: capped copy that saves 8-bit
    (capped_copy,) = apply_video_input_accepts([source_video], {"tooltip": "x"})
    assert type(capped_copy) is VideoFromFile
    assert capped_copy is not source_video
    bound_out = str(tmp_path / "bound.mp4")
    capped_copy.save_to(bound_out)
    assert probe(bound_out) == ("h264", "yuv420p", 8)

    # declared input: original passes through; a cap from an earlier binding is lifted
    declared_result = apply_video_input_accepts([source_video], DECLARED)
    assert declared_result[0] is source_video
    (uncapped,) = apply_video_input_accepts([capped_copy], DECLARED)
    lifted_out = str(tmp_path / "lifted.mp4")
    uncapped.save_to(lifted_out)
    assert probe(lifted_out) == ("h264", "yuv420p10le", 10)

    # declaring depth 8 is the same as not declaring
    depth8_result = apply_video_input_accepts([source_video], {"accepts": {"depth": 8}})
    assert depth8_result[0] is not source_video

    # subclasses, component videos, custom implementations, and non-videos pass through
    from comfy_api.latest._input import VideoInput as VideoInputABC

    class SubVideo(VideoFromFile):
        """Subclass of the file-backed video with no overrides."""

    class CustomVideo(VideoInputABC):
        """Minimal third-party style implementation of the video interface."""

        def get_components(self):
            raise NotImplementedError

        def save_to(self, path, format=None, codec=None, metadata=None):
            raise NotImplementedError

        def as_trimmed(self, start_time=None, duration=None, strict_duration=False):
            return self

    subclass_video = SubVideo(src10)
    component_video = VideoFromComponents(gradient_components)
    custom_video = CustomVideo()
    passthrough = [subclass_video, component_video, custom_video, "not a video", None]
    assert apply_video_input_accepts(passthrough, None) == passthrough
|
||||||
|
|
||||||
|
|
||||||
|
def test_accepts_declaration():
    """Check ``accepts`` handling on ``Video.Input`` and on the built-in video nodes.

    A valid declaration is serialized by ``as_dict``; an absent one is
    omitted. Unknown keys and unsupported depths raise ``ValueError``.
    ``SaveVideo`` and ``VideoSlice`` both expose ``{"depth": 10}`` on their
    ``video`` input.
    """
    from comfy_api.latest import io
    import comfy_extras.nodes_video as nv
    from comfy_execution.graph import get_input_info

    ten_bit_decl = {"depth": 10}

    # Serialization: present when declared, absent otherwise.
    declared_dict = io.Video.Input("video", accepts={"depth": 10}).as_dict()
    assert declared_dict["accepts"] == ten_bit_decl
    plain_dict = io.Video.Input("video").as_dict()
    assert "accepts" not in plain_dict

    # Validation: unknown key first, then an unsupported depth value.
    with pytest.raises(ValueError, match="Unsupported keys"):
        io.Video.Input("video", accepts={"codec": "h264"})
    with pytest.raises(ValueError, match="must be 8 or 10"):
        io.Video.Input("video", accepts={"depth": 12})

    # Both nodes must advertise 10-bit support on their "video" input.
    declaring_nodes = (nv.SaveVideo, nv.VideoSlice)
    for node in declaring_nodes:
        input_types = node.INPUT_TYPES()
        _type, _category, extra_info = get_input_info(node, "video", input_types)
        assert extra_info.get("accepts") == ten_bit_decl, node
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_video_node_bit_depth_handling(tmp_path, monkeypatch, caplog):
    """Check that the SaveVideo node copes with both new and legacy ``save_to`` signatures.

    For a video whose ``save_to`` has no ``bit_depth`` parameter, a 10-bit
    request must log a warning and still write a file rather than raise.
    For a video that supports the parameter, the written file must really
    be 10-bit.
    """
    import comfy_extras.nodes_video as nv
    from comfy_api.latest._io import HiddenHolder
    from comfy_api.latest._input import VideoInput as VideoInputABC

    def fake_output_directory():
        return str(tmp_path)

    # Redirect node output into the pytest temp dir and replace hidden inputs.
    monkeypatch.setattr(nv.folder_paths, "get_output_directory", fake_output_directory)
    monkeypatch.setattr(nv.SaveVideo, "hidden", HiddenHolder.from_dict(None))

    class LegacyVideo(VideoInputABC):
        """Video implementation whose save_to predates the bit_depth parameter."""

        def get_dimensions(self):
            return 16, 16

        def get_components(self):
            raise NotImplementedError

        def save_to(self, path, format=None, codec=None, metadata=None):  # no bit_depth
            with open(path, "wb") as sink:
                sink.write(b"data")

        def as_trimmed(self, start_time=None, duration=None, strict_duration=False):
            return self

    # legacy source: an explicit 10-bit request must not crash; it warns and still saves
    legacy_video = LegacyVideo()
    with caplog.at_level(logging.WARNING):
        nv.SaveVideo.execute(legacy_video, "legacy", "auto", "auto", bit_depth="10-bit")
    assert "does not support bit_depth" in caplog.text
    legacy_files = list(tmp_path.glob("legacy*"))
    assert legacy_files

    # supporting source: bit_depth reaches save_to, so the file really is 10-bit
    levels = torch.linspace(0.25, 0.30, 64)
    ramp = levels.view(1, 1, 64, 1).expand(3, 64, 64, 3).contiguous()
    components = VideoComponents(images=ramp, frame_rate=Fraction(30))
    supporting_video = VideoFromComponents(components)
    nv.SaveVideo.execute(
        supporting_video,
        "supported",
        "auto",
        "auto",
        bit_depth="10-bit",
    )

    written = list(tmp_path.glob("supported*.mp4"))
    assert len(written) == 1
    assert probe(str(written[0])) == ("h264", "yuv420p10le", 10)
|
||||||
Loading…
Reference in New Issue
Block a user