fix VideoFromFile stream source to _ReentrantBytesIO for parallel async use

This commit is contained in:
bigcat88 2025-12-21 15:21:42 +02:00
parent 807538fe6c
commit 28db2757e1
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -13,6 +13,124 @@ import torch
from .._util import VideoContainer, VideoCodec, VideoComponents
class _ReentrantBytesIO(io.BytesIO):
"""Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
def __init__(self, data: bytes):
super().__init__(b"") # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
if data is None:
raise TypeError("data must be bytes, not None")
self._data = data
self._pos = 0
self._len = len(data)
def getvalue(self) -> bytes:
if self.closed:
raise ValueError("I/O operation on closed file.")
return self._data
def getbuffer(self) -> memoryview:
if self.closed:
raise ValueError("I/O operation on closed file.")
return memoryview(self._data)
def readable(self) -> bool:
return True
def writable(self) -> bool:
return False
def seekable(self) -> bool:
return True
def tell(self) -> int:
return self._pos
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if self.closed:
raise ValueError("I/O operation on closed file.")
if whence == io.SEEK_SET:
new_pos = offset
elif whence == io.SEEK_CUR:
new_pos = self._pos + offset
elif whence == io.SEEK_END:
new_pos = self._len + offset
else:
raise ValueError(f"Invalid whence: {whence}")
if new_pos < 0:
raise ValueError("Negative seek position")
self._pos = new_pos
return self._pos
def readinto(self, b) -> int:
if self.closed:
raise ValueError("I/O operation on closed file.")
mv = memoryview(b)
if mv.readonly:
raise TypeError("readinto() argument must be writable")
mv = mv.cast("B")
if self._pos >= self._len:
return 0
n = min(len(mv), self._len - self._pos)
mv[:n] = self._data[self._pos:self._pos + n]
self._pos += n
return n
def readinto1(self, b) -> int:
return self.readinto(b)
def read(self, size: int = -1) -> bytes:
if self.closed:
raise ValueError("I/O operation on closed file.")
if size is None or size < 0:
size = self._len - self._pos
if self._pos >= self._len:
return b""
end = min(self._pos + size, self._len)
out = self._data[self._pos:end]
self._pos = end
return out
def read1(self, size: int = -1) -> bytes:
return self.read(size)
def readline(self, size: int = -1) -> bytes:
if self.closed:
raise ValueError("I/O operation on closed file.")
if self._pos >= self._len:
return b""
end_limit = self._len if size is None or size < 0 else min(self._len, self._pos + size)
nl = self._data.find(b"\n", self._pos, end_limit)
end = (nl + 1) if nl != -1 else end_limit
out = self._data[self._pos:end]
self._pos = end
return out
def readlines(self, hint: int = -1) -> list[bytes]:
if self.closed:
raise ValueError("I/O operation on closed file.")
lines: list[bytes] = []
total = 0
while True:
line = self.readline()
if not line:
break
lines.append(line)
total += len(line)
if hint is not None and 0 <= hint <= total:
break
return lines
def write(self, b) -> int:
raise io.UnsupportedOperation("not writable")
def writelines(self, lines) -> None:
raise io.UnsupportedOperation("not writable")
def truncate(self, size: int | None = None) -> int:
raise io.UnsupportedOperation("not writable")
def container_to_output_format(container_format: str | None) -> str | None:
"""
A container's `format` may be a comma-separated list of formats.
@ -57,21 +175,31 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
def __init__(self, file: str | io.BytesIO):
__data: str | bytes
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
if isinstance(file, str):
self.__data = file
elif isinstance(file, io.BytesIO):
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
self.__data = file.getbuffer().tobytes()
elif isinstance(file, (bytes, bytearray, memoryview)):
self.__data = bytes(file)
else:
raise TypeError(f"Unsupported video source type: {type(file)!r}")
def get_stream_source(self) -> str | io.BytesIO:
"""
Return the underlying file source for efficient streaming.
This avoids unnecessary memory copies when the source is already a file path.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
return self.__file
if isinstance(self.__data, str):
return self.__data
return _ReentrantBytesIO(self.__data)
def get_dimensions(self) -> tuple[int, int]:
"""
@ -80,14 +208,12 @@ class VideoFromFile(VideoInput):
Returns:
Tuple of (width, height)
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
with av.open(self.get_stream_source(), mode="r") as container:
for stream in container.streams:
if stream.type == 'video':
assert isinstance(stream, av.VideoStream)
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
raise ValueError(f"No video stream found in {self._source_label()}")
def get_duration(self) -> float:
"""
@ -96,9 +222,7 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
with av.open(self.get_stream_source(), mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
@ -119,17 +243,14 @@ class VideoFromFile(VideoInput):
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
with av.open(self.get_stream_source(), mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
@ -160,7 +281,7 @@ class VideoFromFile(VideoInput):
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
return frame_count
def get_frame_rate(self) -> Fraction:
@ -168,10 +289,7 @@ class VideoFromFile(VideoInput):
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
with av.open(self.get_stream_source(), mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
@ -193,9 +311,7 @@ class VideoFromFile(VideoInput):
Returns:
Container format as string
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode='r') as container:
with av.open(self.get_stream_source(), mode='r') as container:
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
@ -239,11 +355,8 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
with av.open(self.get_stream_source(), mode='r') as container:
return self.get_components_internal(container)
raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to(
self,
@ -252,9 +365,7 @@ class VideoFromFile(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
with av.open(self.get_stream_source(), mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
reuse_streams = True
@ -306,9 +417,12 @@ class VideoFromFile(VideoInput):
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
raise ValueError(f"No video stream found in file '{self._source_label()}'")
return video_stream
def _source_label(self) -> str:
return self.__data if isinstance(self.__data, str) else f"<in-memory video: {len(self.__data)} bytes>"
class VideoFromComponents(VideoInput):
"""