This commit is contained in:
Alexander Piskun 2026-01-09 12:29:31 +00:00 committed by GitHub
commit d79ff172a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,6 +13,124 @@ import torch
from .._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
class _ReentrantBytesIO(io.BytesIO):
"""Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
def __init__(self, data: bytes):
super().__init__(b"") # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
if data is None:
raise TypeError("data must be bytes, not None")
self._data = data
self._pos = 0
self._len = len(data)
def getvalue(self) -> bytes:
if self.closed:
raise ValueError("I/O operation on closed file.")
return self._data
def getbuffer(self) -> memoryview:
if self.closed:
raise ValueError("I/O operation on closed file.")
return memoryview(self._data)
def readable(self) -> bool:
return True
def writable(self) -> bool:
return False
def seekable(self) -> bool:
return True
def tell(self) -> int:
return self._pos
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if self.closed:
raise ValueError("I/O operation on closed file.")
if whence == io.SEEK_SET:
new_pos = offset
elif whence == io.SEEK_CUR:
new_pos = self._pos + offset
elif whence == io.SEEK_END:
new_pos = self._len + offset
else:
raise ValueError(f"Invalid whence: {whence}")
if new_pos < 0:
raise ValueError("Negative seek position")
self._pos = new_pos
return self._pos
def readinto(self, b) -> int:
if self.closed:
raise ValueError("I/O operation on closed file.")
mv = memoryview(b)
if mv.readonly:
raise TypeError("readinto() argument must be writable")
mv = mv.cast("B")
if self._pos >= self._len:
return 0
n = min(len(mv), self._len - self._pos)
mv[:n] = self._data[self._pos:self._pos + n]
self._pos += n
return n
def readinto1(self, b) -> int:
return self.readinto(b)
def read(self, size: int = -1) -> bytes:
if self.closed:
raise ValueError("I/O operation on closed file.")
if size is None or size < 0:
size = self._len - self._pos
if self._pos >= self._len:
return b""
end = min(self._pos + size, self._len)
out = self._data[self._pos:end]
self._pos = end
return out
def read1(self, size: int = -1) -> bytes:
return self.read(size)
def readline(self, size: int = -1) -> bytes:
if self.closed:
raise ValueError("I/O operation on closed file.")
if self._pos >= self._len:
return b""
end_limit = self._len if size is None or size < 0 else min(self._len, self._pos + size)
nl = self._data.find(b"\n", self._pos, end_limit)
end = (nl + 1) if nl != -1 else end_limit
out = self._data[self._pos:end]
self._pos = end
return out
def readlines(self, hint: int = -1) -> list[bytes]:
if self.closed:
raise ValueError("I/O operation on closed file.")
lines: list[bytes] = []
total = 0
while True:
line = self.readline()
if not line:
break
lines.append(line)
total += len(line)
if hint is not None and 0 <= hint <= total:
break
return lines
def write(self, b) -> int:
raise io.UnsupportedOperation("not writable")
def writelines(self, lines) -> None:
raise io.UnsupportedOperation("not writable")
def truncate(self, size: int | None = None) -> int:
raise io.UnsupportedOperation("not writable")
def container_to_output_format(container_format: str | None) -> str | None: def container_to_output_format(container_format: str | None) -> str | None:
""" """
A container's `format` may be a comma-separated list of formats. A container's `format` may be a comma-separated list of formats.
@ -57,21 +175,31 @@ class VideoFromFile(VideoInput):
Class representing video input from a file. Class representing video input from a file.
""" """
def __init__(self, file: str | io.BytesIO): __data: str | bytes
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
""" """
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents. containing the file contents.
""" """
self.__file = file if isinstance(file, str):
self.__data = file
elif isinstance(file, io.BytesIO):
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
self.__data = file.getbuffer().tobytes()
elif isinstance(file, (bytes, bytearray, memoryview)):
self.__data = bytes(file)
else:
raise TypeError(f"Unsupported video source type: {type(file)!r}")
def get_stream_source(self) -> str | io.BytesIO: def get_stream_source(self) -> str | io.BytesIO:
""" """
Return the underlying file source for efficient streaming. Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path. This avoids unnecessary memory copies when the source is already a file path.
""" """
if isinstance(self.__file, io.BytesIO): if isinstance(self.__data, str):
self.__file.seek(0) return self.__data
return self.__file return _ReentrantBytesIO(self.__data)
def get_dimensions(self) -> tuple[int, int]: def get_dimensions(self) -> tuple[int, int]:
""" """
@ -80,14 +208,12 @@ class VideoFromFile(VideoInput):
Returns: Returns:
Tuple of (width, height) Tuple of (width, height)
""" """
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode="r") as container:
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
for stream in container.streams: for stream in container.streams:
if stream.type == 'video': if stream.type == 'video':
assert isinstance(stream, av.VideoStream) assert isinstance(stream, av.VideoStream)
return stream.width, stream.height return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in {self._source_label()}")
def get_duration(self) -> float: def get_duration(self) -> float:
""" """
@ -96,9 +222,7 @@ class VideoFromFile(VideoInput):
Returns: Returns:
Duration in seconds Duration in seconds
""" """
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode="r") as container:
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None: if container.duration is not None:
return float(container.duration / av.time_base) return float(container.duration / av.time_base)
@ -119,17 +243,14 @@ class VideoFromFile(VideoInput):
if frame_count > 0: if frame_count > 0:
return float(frame_count / video_stream.average_rate) return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'") raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
def get_frame_count(self) -> int: def get_frame_count(self) -> int:
""" """
Returns the number of frames in the video without materializing them as Returns the number of frames in the video without materializing them as
torch tensors. torch tensors.
""" """
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode="r") as container:
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container) video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available # 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0: if video_stream.frames and video_stream.frames > 0:
@ -160,7 +281,7 @@ class VideoFromFile(VideoInput):
frame_count += 1 frame_count += 1
if frame_count == 0: if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'") raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
return frame_count return frame_count
def get_frame_rate(self) -> Fraction: def get_frame_rate(self) -> Fraction:
@ -168,10 +289,7 @@ class VideoFromFile(VideoInput):
Returns the average frame rate of the video using container metadata Returns the average frame rate of the video using container metadata
without decoding all frames. without decoding all frames.
""" """
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode="r") as container:
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container) video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like) # Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate: if video_stream.average_rate:
@ -193,9 +311,7 @@ class VideoFromFile(VideoInput):
Returns: Returns:
Container format as string Container format as string
""" """
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode='r') as container:
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
return container.format.name return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents: def get_components_internal(self, container: InputContainer) -> VideoComponents:
@ -239,11 +355,8 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode='r') as container:
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container) return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to( def save_to(
self, self,
@ -252,9 +365,7 @@ class VideoFromFile(VideoInput):
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None
): ):
if isinstance(self.__file, io.BytesIO): with av.open(self.get_stream_source(), mode='r') as container:
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True reuse_streams = True
@ -306,9 +417,12 @@ class VideoFromFile(VideoInput):
def _get_first_video_stream(self, container: InputContainer): def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None) video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None: if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self._source_label()}'")
return video_stream return video_stream
def _source_label(self) -> str:
return self.__data if isinstance(self.__data, str) else f"<in-memory video: {len(self.__data)} bytes>"
class VideoFromComponents(VideoInput): class VideoFromComponents(VideoInput):
""" """