mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 16:50:17 +08:00
Support non-strict duration
This commit is contained in:
parent
37c2a960cb
commit
fa8241f85e
@ -35,7 +35,12 @@ class VideoInput(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = False,
|
||||
) -> VideoInput | None:
|
||||
"""
|
||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||
|
||||
|
||||
@ -65,9 +65,6 @@ class VideoFromFile(VideoInput):
|
||||
self.__file = file
|
||||
self.__start_time = start_time
|
||||
self.__duration = duration
|
||||
if self.get_duration() < duration:
|
||||
raise ValueError(f"Can not initialize video of negative duration:\nSource duration: {self.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}")
|
||||
self.__duration = duration
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -101,20 +98,28 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
raw_duration = self._get_raw_duration()
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
if self.__duration:
|
||||
return self.__duration
|
||||
return min(self.__duration, duration_from_start)
|
||||
return duration_from_start
|
||||
|
||||
def _get_raw_duration(self) -> float:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base) - self.__start_time
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate) - self.start_time
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
@ -128,7 +133,7 @@ class VideoFromFile(VideoInput):
|
||||
for packet in frame_iterator:
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate) - self.start_time
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
@ -142,32 +147,39 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0 and not self.__start_time and not self.__duration:
|
||||
# 1. Prefer the frames field if available and usable
|
||||
if (
|
||||
video_stream.frames
|
||||
and video_stream.frames > 0
|
||||
and not self.__start_time
|
||||
and not self.__duration
|
||||
):
|
||||
return int(video_stream.frames)
|
||||
|
||||
if self.__duration:
|
||||
return self.__duration / float(video_stream.average_rate)
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base) - self.__start_time
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base) - self.start_time
|
||||
raw_duration = float(video_stream.duration * video_stream.time_base)
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
duration_seconds = min(self.__duration, duration_from_start)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
frame_count = 1
|
||||
start_pts = int(self.__start_time / video_stream.time_base)
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
@ -180,6 +192,8 @@ class VideoFromFile(VideoInput):
|
||||
else:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= end_pts:
|
||||
break
|
||||
frame_count += 1
|
||||
return frame_count
|
||||
|
||||
@ -219,11 +233,15 @@ class VideoFromFile(VideoInput):
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
video_stream = container.streams.video[0]
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
# Get video frames
|
||||
frames = []
|
||||
start_pts = int(self.__start_time / video_stream.time_base)
|
||||
end_pts = int((self.__start_time + self.__duration) / video_stream.time_base)
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
for frame in container.decode(video_stream):
|
||||
if frame.pts < start_pts:
|
||||
@ -248,20 +266,21 @@ class VideoFromFile(VideoInput):
|
||||
audio_frames = []
|
||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||
frames = itertools.chain.from_iterable(
|
||||
map(resample, container.decode(audio_stream)))
|
||||
map(resample, container.decode(audio_stream))
|
||||
)
|
||||
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = self.__start_time - frame.pts * audio_stream.time_base
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
if has_first_frame:
|
||||
audio_frames.append(frame.to_ndarray()[...,to_skip:])
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if frame.time > self.__start_time + self.__duration:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
@ -290,7 +309,7 @@ class VideoFromFile(VideoInput):
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
@ -309,10 +328,7 @@ class VideoFromFile(VideoInput):
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -346,15 +362,21 @@ class VideoFromFile(VideoInput):
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
if len(container.streams.video):
|
||||
return container.streams.video[0]
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def as_trimmed(self, start_time: float=0, duration: float=0) -> VideoInput|None:
|
||||
if self.get_duration() < start_time + duration:
|
||||
def as_trimmed(
|
||||
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
|
||||
) -> VideoInput | None:
|
||||
trimmed = VideoFromFile(
|
||||
self.get_stream_source(),
|
||||
start_time=start_time + self.__start_time,
|
||||
duration=duration + self.__duration,
|
||||
)
|
||||
if trimmed.get_duration() < duration and strict_duration:
|
||||
return None
|
||||
return VideoFromFile(self.get_stream_source(), start_time=start_time + self.__start_time, duration=duration + self.__duration)
|
||||
return trimmed
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
@ -369,7 +391,7 @@ class VideoFromComponents(VideoInput):
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def save_to(
|
||||
@ -377,7 +399,7 @@ class VideoFromComponents(VideoInput):
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
@ -428,7 +450,13 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
|
||||
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = True,
|
||||
) -> VideoInput | None:
|
||||
if self.get_duration() < start_time + duration:
|
||||
return None
|
||||
#TODO Consider tracking duration and trimming at time of save?
|
||||
|
||||
@ -221,7 +221,7 @@ class VideoSlice(io.ComfyNode):
|
||||
"start_time",
|
||||
default=0.0,
|
||||
max=1e5,
|
||||
min=0.0,
|
||||
min=-1e5,
|
||||
step=0.001,
|
||||
tooltip="Start time in seconds",
|
||||
),
|
||||
@ -230,7 +230,7 @@ class VideoSlice(io.ComfyNode):
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
step=0.001,
|
||||
tooltip="Duration in seconds",
|
||||
tooltip="Duration in seconds, or 0 for unlimited duration",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -240,7 +240,7 @@ class VideoSlice(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video, start_time, duration) -> io.NodeOutput:
|
||||
trimmed = video.as_trimmed(start_time, duration)
|
||||
trimmed = video.as_trimmed(start_time, duration, strict_duration=False)
|
||||
if trimmed is not None:
|
||||
return io.NodeOutput(trimmed)
|
||||
raise ValueError(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user