mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
fix VideoFromFile stream source to _ReentrantBytesIO for parallel async use
This commit is contained in:
parent
807538fe6c
commit
28db2757e1
@ -13,6 +13,124 @@ import torch
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
class _ReentrantBytesIO(io.BytesIO):
|
||||
"""Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
super().__init__(b"") # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
|
||||
if data is None:
|
||||
raise TypeError("data must be bytes, not None")
|
||||
self._data = data
|
||||
self._pos = 0
|
||||
self._len = len(data)
|
||||
|
||||
def getvalue(self) -> bytes:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
return self._data
|
||||
|
||||
def getbuffer(self) -> memoryview:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
return memoryview(self._data)
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def writable(self) -> bool:
|
||||
return False
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return True
|
||||
|
||||
def tell(self) -> int:
|
||||
return self._pos
|
||||
|
||||
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
if whence == io.SEEK_SET:
|
||||
new_pos = offset
|
||||
elif whence == io.SEEK_CUR:
|
||||
new_pos = self._pos + offset
|
||||
elif whence == io.SEEK_END:
|
||||
new_pos = self._len + offset
|
||||
else:
|
||||
raise ValueError(f"Invalid whence: {whence}")
|
||||
if new_pos < 0:
|
||||
raise ValueError("Negative seek position")
|
||||
self._pos = new_pos
|
||||
return self._pos
|
||||
|
||||
def readinto(self, b) -> int:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
mv = memoryview(b)
|
||||
if mv.readonly:
|
||||
raise TypeError("readinto() argument must be writable")
|
||||
mv = mv.cast("B")
|
||||
if self._pos >= self._len:
|
||||
return 0
|
||||
n = min(len(mv), self._len - self._pos)
|
||||
mv[:n] = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return n
|
||||
|
||||
def readinto1(self, b) -> int:
|
||||
return self.readinto(b)
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
if size is None or size < 0:
|
||||
size = self._len - self._pos
|
||||
if self._pos >= self._len:
|
||||
return b""
|
||||
end = min(self._pos + size, self._len)
|
||||
out = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return out
|
||||
|
||||
def read1(self, size: int = -1) -> bytes:
|
||||
return self.read(size)
|
||||
|
||||
def readline(self, size: int = -1) -> bytes:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
if self._pos >= self._len:
|
||||
return b""
|
||||
end_limit = self._len if size is None or size < 0 else min(self._len, self._pos + size)
|
||||
nl = self._data.find(b"\n", self._pos, end_limit)
|
||||
end = (nl + 1) if nl != -1 else end_limit
|
||||
out = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return out
|
||||
|
||||
def readlines(self, hint: int = -1) -> list[bytes]:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
lines: list[bytes] = []
|
||||
total = 0
|
||||
while True:
|
||||
line = self.readline()
|
||||
if not line:
|
||||
break
|
||||
lines.append(line)
|
||||
total += len(line)
|
||||
if hint is not None and 0 <= hint <= total:
|
||||
break
|
||||
return lines
|
||||
|
||||
def write(self, b) -> int:
|
||||
raise io.UnsupportedOperation("not writable")
|
||||
|
||||
def writelines(self, lines) -> None:
|
||||
raise io.UnsupportedOperation("not writable")
|
||||
|
||||
def truncate(self, size: int | None = None) -> int:
|
||||
raise io.UnsupportedOperation("not writable")
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
@ -57,21 +175,31 @@ class VideoFromFile(VideoInput):
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
__data: str | bytes
|
||||
|
||||
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
if isinstance(file, str):
|
||||
self.__data = file
|
||||
elif isinstance(file, io.BytesIO):
|
||||
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
|
||||
self.__data = file.getbuffer().tobytes()
|
||||
elif isinstance(file, (bytes, bytearray, memoryview)):
|
||||
self.__data = bytes(file)
|
||||
else:
|
||||
raise TypeError(f"Unsupported video source type: {type(file)!r}")
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
if isinstance(self.__data, str):
|
||||
return self.__data
|
||||
return _ReentrantBytesIO(self.__data)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
@ -80,14 +208,12 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
raise ValueError(f"No video stream found in {self._source_label()}")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
@ -96,9 +222,7 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
@ -119,17 +243,14 @@ class VideoFromFile(VideoInput):
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video without materializing them as
|
||||
torch tensors.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
@ -160,7 +281,7 @@ class VideoFromFile(VideoInput):
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -168,10 +289,7 @@ class VideoFromFile(VideoInput):
|
||||
Returns the average frame rate of the video using container metadata
|
||||
without decoding all frames.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||
if video_stream.average_rate:
|
||||
@ -193,9 +311,7 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
@ -239,11 +355,8 @@ class VideoFromFile(VideoInput):
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
@ -252,9 +365,7 @@ class VideoFromFile(VideoInput):
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
@ -306,9 +417,12 @@ class VideoFromFile(VideoInput):
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
raise ValueError(f"No video stream found in file '{self._source_label()}'")
|
||||
return video_stream
|
||||
|
||||
def _source_label(self) -> str:
|
||||
return self.__data if isinstance(self.__data, str) else f"<in-memory video: {len(self.__data)} bytes>"
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user