add get_frame_count and get_frame_rate methods to VideoInput class (#10851)

This commit is contained in:
Alexander Piskun 2025-11-24 20:24:29 +02:00 committed by GitHub
parent 3bd71554a2
commit 1286fcfe40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 106 additions and 9 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from fractions import Fraction
from typing import Optional, Union, IO from typing import Optional, Union, IO
import io import io
import av import av
@ -72,6 +73,33 @@ class VideoInput(ABC):
frame_count = components.images.shape[0] frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate) return float(frame_count / components.frame_rate)
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video.
Default implementation uses :meth:`get_components`, which may require
loading all frames into memory. File-based implementations should
override this method and use container/stream metadata instead.
Returns:
Total number of frames as an integer.
"""
return int(self.get_components().images.shape[0])
def get_frame_rate(self) -> Fraction:
"""
Returns the frame rate of the video.
Default implementation materializes the video into memory via
`get_components()`. Subclasses that can inspect the underlying
container (e.g. `VideoFromFile`) should override this with a more
efficient implementation.
Returns:
Frame rate as a Fraction.
"""
return self.get_components().frame_rate
def get_container_format(self) -> str: def get_container_format(self) -> str:
""" """
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

View File

@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'") raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
"""
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
return Fraction(video_stream.average_rate)
# Fallback: estimate from frames + duration if available
if video_stream.frames and container.duration:
duration_seconds = float(container.duration / av.time_base)
if duration_seconds > 0:
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
# Last resort: match get_components_internal default
return Fraction(1)
def get_container_format(self) -> str: def get_container_format(self) -> str:
""" """
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
packet.stream = stream_map[packet.stream] packet.stream = stream_map[packet.stream]
output_container.mux(packet) output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
class VideoFromComponents(VideoInput): class VideoFromComponents(VideoInput):
""" """
Class representing video input from tensors. Class representing video input from tensors.

View File

@ -5,8 +5,7 @@ import aiohttp
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.input.video_types import VideoInput from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import topaz_api from comfy_api_nodes.apis import topaz_api
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
video: VideoInput, video: Input.Video,
upscaler_enabled: bool, upscaler_enabled: bool,
upscaler_model: str, upscaler_model: str,
upscaler_resolution: str, upscaler_resolution: str,
@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if upscaler_enabled is False and interpolation_enabled is False: if upscaler_enabled is False and interpolation_enabled is False:
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
src_width, src_height = video.get_dimensions()
video_components = video.get_components()
src_frame_rate = int(video_components.frame_rate)
duration_sec = video.get_duration()
estimated_frames = int(duration_sec * src_frame_rate)
validate_container_format_is_mp4(video) validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source() src_video_stream = video.get_stream_source()
target_width = src_width target_width = src_width
target_height = src_height target_height = src_height
@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode):
container="mp4", container="mp4",
size=get_fs_object_size(src_video_stream), size=get_fs_object_size(src_video_stream),
duration=int(duration_sec), duration=int(duration_sec),
frameCount=estimated_frames, frameCount=video.get_frame_count(),
frameRate=src_frame_rate, frameRate=src_frame_rate,
resolution=topaz_api.Resolution(width=src_width, height=src_height), resolution=topaz_api.Resolution(width=src_width, height=src_height),
), ),