Support non-strict duration
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Austin Mroz 2026-01-30 10:46:11 -08:00
parent 37c2a960cb
commit fa8241f85e
No known key found for this signature in database
3 changed files with 78 additions and 45 deletions

View File

@ -35,7 +35,12 @@ class VideoInput(ABC):
pass
@abstractmethod
def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = False,
) -> VideoInput | None:
"""
Create a new VideoInput which is trimmed to have the corresponding start_time and duration

View File

@ -65,9 +65,6 @@ class VideoFromFile(VideoInput):
self.__file = file
self.__start_time = start_time
self.__duration = duration
if self.get_duration() < duration:
raise ValueError(f"Can not initialize video of negative duration:\nSource duration: {self.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}")
self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@ -101,20 +98,28 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
raw_duration = self._get_raw_duration()
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
if self.__duration:
return self.__duration
return min(self.__duration, duration_from_start)
return duration_from_start
def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base) - self.__start_time
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate) - self.start_time
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
@ -128,7 +133,7 @@ class VideoFromFile(VideoInput):
for packet in frame_iterator:
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate) - self.start_time
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
@ -142,32 +147,39 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0 and not self.__start_time and not self.__duration:
# 1. Prefer the frames field if available and usable
if (
video_stream.frames
and video_stream.frames > 0
and not self.__start_time
and not self.__duration
):
return int(video_stream.frames)
if self.__duration:
return self.__duration / float(video_stream.average_rate)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base) - self.__start_time
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base) - self.start_time
raw_duration = float(video_stream.duration * video_stream.time_base)
if self.__start_time < 0:
duration_from_start = min(raw_duration, -self.__start_time)
else:
duration_from_start = raw_duration - self.__start_time
duration_seconds = min(self.__duration, duration_from_start)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
frame_count = 1
start_pts = int(self.__start_time / video_stream.time_base)
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
frame_iterator = (
container.decode(video_stream)
@ -180,6 +192,8 @@ class VideoFromFile(VideoInput):
else:
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
for frame in frame_iterator:
if frame.pts >= end_pts:
break
frame_count += 1
return frame_count
@ -219,11 +233,15 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
video_stream = container.streams.video[0]
video_stream = self._get_first_video_stream(container)
if self.__start_time < 0:
start_time = max(self._get_raw_duration() + self.__start_time, 0)
else:
start_time = self.__start_time
# Get video frames
frames = []
start_pts = int(self.__start_time / video_stream.time_base)
end_pts = int((self.__start_time + self.__duration) / video_stream.time_base)
start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
@ -248,20 +266,21 @@ class VideoFromFile(VideoInput):
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream)))
map(resample, container.decode(audio_stream))
)
has_first_frame = False
for frame in frames:
offset_seconds = self.__start_time - frame.pts * audio_stream.time_base
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[...,to_skip:])
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if frame.time > self.__start_time + self.__duration:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
@ -290,7 +309,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@ -309,10 +328,7 @@ class VideoFromFile(VideoInput):
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path,
format=format,
codec=codec,
metadata=metadata
path, format=format, codec=codec, metadata=metadata
)
streams = container.streams
@ -346,15 +362,21 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
    """Return the first video stream found in ``container``.

    Args:
        container: An opened input container whose ``streams.video``
            sequence lists its video streams.

    Returns:
        The first entry of ``container.streams.video``.

    Raises:
        ValueError: If the container has no video stream.
    """
    video_streams = container.streams.video
    if len(video_streams) == 0:
        raise ValueError(f"No video stream found in file '{self.__file}'")
    return video_streams[0]
def as_trimmed(
    self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
) -> VideoInput | None:
    """Create a trimmed view over the same stream source.

    Args:
        start_time: Offset in seconds, added to this video's existing
            start offset.
        duration: Requested length in seconds, added to this video's
            existing duration setting.
        strict_duration: When true, the trim is rejected if the resulting
            video reports a duration shorter than ``duration``.

    Returns:
        A new ``VideoFromFile``, or ``None`` when ``strict_duration`` is
        true and the trimmed video is shorter than requested.
    """
    # NOTE(review): summing the two durations looks wrong when this video
    # is already trimmed (re-trimming a 5 s clip to 2 s yields a 7 s
    # window) -- confirm the intended composition before changing it.
    trimmed = VideoFromFile(
        self.get_stream_source(),
        start_time=start_time + self.__start_time,
        duration=duration + self.__duration,
    )
    # Test the flag first so a non-strict trim never has to probe the
    # media for its duration.
    if strict_duration and trimmed.get_duration() < duration:
        return None
    return trimmed
class VideoFromComponents(VideoInput):
@ -369,7 +391,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
frame_rate=self.__components.frame_rate
frame_rate=self.__components.frame_rate,
)
def save_to(
@ -377,7 +399,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@ -428,7 +450,13 @@ class VideoFromComponents(VideoInput):
# Flush encoder
output.mux(audio_stream.encode(None))
def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
def as_trimmed(
self,
start_time: float | None = None,
duration: float | None = None,
strict_duration: bool = True,
) -> VideoInput | None:
if self.get_duration() < start_time + duration:
return None
#TODO Consider tracking duration and trimming at time of save?

View File

@ -221,7 +221,7 @@ class VideoSlice(io.ComfyNode):
"start_time",
default=0.0,
max=1e5,
min=0.0,
min=-1e5,
step=0.001,
tooltip="Start time in seconds",
),
@ -230,7 +230,7 @@ class VideoSlice(io.ComfyNode):
default=0.0,
min=0.0,
step=0.001,
tooltip="Duration in seconds",
tooltip="Duration in seconds, or 0 for unlimited duration",
),
],
outputs=[
@ -240,7 +240,7 @@ class VideoSlice(io.ComfyNode):
@classmethod
def execute(cls, video, start_time, duration) -> io.NodeOutput:
trimmed = video.as_trimmed(start_time, duration)
trimmed = video.as_trimmed(start_time, duration, strict_duration=False)
if trimmed is not None:
return io.NodeOutput(trimmed)
raise ValueError(