mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Support non-strict duration
This commit is contained in:
parent
37c2a960cb
commit
fa8241f85e
@ -35,7 +35,12 @@ class VideoInput(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
|
def as_trimmed(
|
||||||
|
self,
|
||||||
|
start_time: float | None = None,
|
||||||
|
duration: float | None = None,
|
||||||
|
strict_duration: bool = False,
|
||||||
|
) -> VideoInput | None:
|
||||||
"""
|
"""
|
||||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||||
|
|
||||||
|
|||||||
@ -65,9 +65,6 @@ class VideoFromFile(VideoInput):
|
|||||||
self.__file = file
|
self.__file = file
|
||||||
self.__start_time = start_time
|
self.__start_time = start_time
|
||||||
self.__duration = duration
|
self.__duration = duration
|
||||||
if self.get_duration() < duration:
|
|
||||||
raise ValueError(f"Can not initialize video of negative duration:\nSource duration: {self.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}")
|
|
||||||
self.__duration = duration
|
|
||||||
|
|
||||||
def get_stream_source(self) -> str | io.BytesIO:
|
def get_stream_source(self) -> str | io.BytesIO:
|
||||||
"""
|
"""
|
||||||
@ -101,20 +98,28 @@ class VideoFromFile(VideoInput):
|
|||||||
Returns:
|
Returns:
|
||||||
Duration in seconds
|
Duration in seconds
|
||||||
"""
|
"""
|
||||||
|
raw_duration = self._get_raw_duration()
|
||||||
|
if self.__start_time < 0:
|
||||||
|
duration_from_start = min(raw_duration, -self.__start_time)
|
||||||
|
else:
|
||||||
|
duration_from_start = raw_duration - self.__start_time
|
||||||
if self.__duration:
|
if self.__duration:
|
||||||
return self.__duration
|
return min(self.__duration, duration_from_start)
|
||||||
|
return duration_from_start
|
||||||
|
|
||||||
|
def _get_raw_duration(self) -> float:
|
||||||
if isinstance(self.__file, io.BytesIO):
|
if isinstance(self.__file, io.BytesIO):
|
||||||
self.__file.seek(0)
|
self.__file.seek(0)
|
||||||
with av.open(self.__file, mode="r") as container:
|
with av.open(self.__file, mode="r") as container:
|
||||||
if container.duration is not None:
|
if container.duration is not None:
|
||||||
return float(container.duration / av.time_base) - self.__start_time
|
return float(container.duration / av.time_base)
|
||||||
|
|
||||||
# Fallback: calculate from frame count and frame rate
|
# Fallback: calculate from frame count and frame rate
|
||||||
video_stream = next(
|
video_stream = next(
|
||||||
(s for s in container.streams if s.type == "video"), None
|
(s for s in container.streams if s.type == "video"), None
|
||||||
)
|
)
|
||||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||||
return float(video_stream.frames / video_stream.average_rate) - self.start_time
|
return float(video_stream.frames / video_stream.average_rate)
|
||||||
|
|
||||||
# Last resort: decode frames to count them
|
# Last resort: decode frames to count them
|
||||||
if video_stream and video_stream.average_rate:
|
if video_stream and video_stream.average_rate:
|
||||||
@ -128,7 +133,7 @@ class VideoFromFile(VideoInput):
|
|||||||
for packet in frame_iterator:
|
for packet in frame_iterator:
|
||||||
frame_count += 1
|
frame_count += 1
|
||||||
if frame_count > 0:
|
if frame_count > 0:
|
||||||
return float(frame_count / video_stream.average_rate) - self.start_time
|
return float(frame_count / video_stream.average_rate)
|
||||||
|
|
||||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||||
|
|
||||||
@ -142,32 +147,39 @@ class VideoFromFile(VideoInput):
|
|||||||
|
|
||||||
with av.open(self.__file, mode="r") as container:
|
with av.open(self.__file, mode="r") as container:
|
||||||
video_stream = self._get_first_video_stream(container)
|
video_stream = self._get_first_video_stream(container)
|
||||||
# 1. Prefer the frames field if available
|
# 1. Prefer the frames field if available and usable
|
||||||
if video_stream.frames and video_stream.frames > 0 and not self.__start_time and not self.__duration:
|
if (
|
||||||
|
video_stream.frames
|
||||||
|
and video_stream.frames > 0
|
||||||
|
and not self.__start_time
|
||||||
|
and not self.__duration
|
||||||
|
):
|
||||||
return int(video_stream.frames)
|
return int(video_stream.frames)
|
||||||
|
|
||||||
if self.__duration:
|
|
||||||
return self.__duration / float(video_stream.average_rate)
|
|
||||||
# 2. Try to estimate from duration and average_rate using only metadata
|
# 2. Try to estimate from duration and average_rate using only metadata
|
||||||
if container.duration is not None and video_stream.average_rate:
|
|
||||||
duration_seconds = float(container.duration / av.time_base) - self.__start_time
|
|
||||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
|
||||||
if estimated_frames > 0:
|
|
||||||
return estimated_frames
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
getattr(video_stream, "duration", None) is not None
|
getattr(video_stream, "duration", None) is not None
|
||||||
and getattr(video_stream, "time_base", None) is not None
|
and getattr(video_stream, "time_base", None) is not None
|
||||||
and video_stream.average_rate
|
and video_stream.average_rate
|
||||||
):
|
):
|
||||||
duration_seconds = float(video_stream.duration * video_stream.time_base) - self.start_time
|
raw_duration = float(video_stream.duration * video_stream.time_base)
|
||||||
|
if self.__start_time < 0:
|
||||||
|
duration_from_start = min(raw_duration, -self.__start_time)
|
||||||
|
else:
|
||||||
|
duration_from_start = raw_duration - self.__start_time
|
||||||
|
duration_seconds = min(self.__duration, duration_from_start)
|
||||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||||
if estimated_frames > 0:
|
if estimated_frames > 0:
|
||||||
return estimated_frames
|
return estimated_frames
|
||||||
|
|
||||||
# 3. Last resort: decode frames and count them (streaming)
|
# 3. Last resort: decode frames and count them (streaming)
|
||||||
|
if self.__start_time < 0:
|
||||||
|
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||||
|
else:
|
||||||
|
start_time = self.__start_time
|
||||||
frame_count = 1
|
frame_count = 1
|
||||||
start_pts = int(self.__start_time / video_stream.time_base)
|
start_pts = int(start_time / video_stream.time_base)
|
||||||
|
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||||
container.seek(start_pts, stream=video_stream)
|
container.seek(start_pts, stream=video_stream)
|
||||||
frame_iterator = (
|
frame_iterator = (
|
||||||
container.decode(video_stream)
|
container.decode(video_stream)
|
||||||
@ -180,6 +192,8 @@ class VideoFromFile(VideoInput):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
||||||
for frame in frame_iterator:
|
for frame in frame_iterator:
|
||||||
|
if frame.pts >= end_pts:
|
||||||
|
break
|
||||||
frame_count += 1
|
frame_count += 1
|
||||||
return frame_count
|
return frame_count
|
||||||
|
|
||||||
@ -219,11 +233,15 @@ class VideoFromFile(VideoInput):
|
|||||||
return container.format.name
|
return container.format.name
|
||||||
|
|
||||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||||
video_stream = container.streams.video[0]
|
video_stream = self._get_first_video_stream(container)
|
||||||
|
if self.__start_time < 0:
|
||||||
|
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||||
|
else:
|
||||||
|
start_time = self.__start_time
|
||||||
# Get video frames
|
# Get video frames
|
||||||
frames = []
|
frames = []
|
||||||
start_pts = int(self.__start_time / video_stream.time_base)
|
start_pts = int(start_time / video_stream.time_base)
|
||||||
end_pts = int((self.__start_time + self.__duration) / video_stream.time_base)
|
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||||
container.seek(start_pts, stream=video_stream)
|
container.seek(start_pts, stream=video_stream)
|
||||||
for frame in container.decode(video_stream):
|
for frame in container.decode(video_stream):
|
||||||
if frame.pts < start_pts:
|
if frame.pts < start_pts:
|
||||||
@ -248,20 +266,21 @@ class VideoFromFile(VideoInput):
|
|||||||
audio_frames = []
|
audio_frames = []
|
||||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||||
frames = itertools.chain.from_iterable(
|
frames = itertools.chain.from_iterable(
|
||||||
map(resample, container.decode(audio_stream)))
|
map(resample, container.decode(audio_stream))
|
||||||
|
)
|
||||||
|
|
||||||
has_first_frame = False
|
has_first_frame = False
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
offset_seconds = self.__start_time - frame.pts * audio_stream.time_base
|
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||||
if to_skip < frame.samples:
|
if to_skip < frame.samples:
|
||||||
has_first_frame = True
|
has_first_frame = True
|
||||||
break
|
break
|
||||||
if has_first_frame:
|
if has_first_frame:
|
||||||
audio_frames.append(frame.to_ndarray()[...,to_skip:])
|
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||||
|
|
||||||
for frame in frames:
|
for frame in frames:
|
||||||
if frame.time > self.__start_time + self.__duration:
|
if frame.time > start_time + self.__duration:
|
||||||
break
|
break
|
||||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||||
if len(audio_frames) > 0:
|
if len(audio_frames) > 0:
|
||||||
@ -290,7 +309,7 @@ class VideoFromFile(VideoInput):
|
|||||||
path: str | io.BytesIO,
|
path: str | io.BytesIO,
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
if isinstance(self.__file, io.BytesIO):
|
if isinstance(self.__file, io.BytesIO):
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
@ -309,10 +328,7 @@ class VideoFromFile(VideoInput):
|
|||||||
components = self.get_components_internal(container)
|
components = self.get_components_internal(container)
|
||||||
video = VideoFromComponents(components)
|
video = VideoFromComponents(components)
|
||||||
return video.save_to(
|
return video.save_to(
|
||||||
path,
|
path, format=format, codec=codec, metadata=metadata
|
||||||
format=format,
|
|
||||||
codec=codec,
|
|
||||||
metadata=metadata
|
|
||||||
)
|
)
|
||||||
|
|
||||||
streams = container.streams
|
streams = container.streams
|
||||||
@ -346,15 +362,21 @@ class VideoFromFile(VideoInput):
|
|||||||
output_container.mux(packet)
|
output_container.mux(packet)
|
||||||
|
|
||||||
def _get_first_video_stream(self, container: InputContainer):
|
def _get_first_video_stream(self, container: InputContainer):
|
||||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
if len(container.streams.video):
|
||||||
if video_stream is None:
|
return container.streams.video[0]
|
||||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
return video_stream
|
|
||||||
|
|
||||||
def as_trimmed(self, start_time: float=0, duration: float=0) -> VideoInput|None:
|
def as_trimmed(
|
||||||
if self.get_duration() < start_time + duration:
|
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
|
||||||
|
) -> VideoInput | None:
|
||||||
|
trimmed = VideoFromFile(
|
||||||
|
self.get_stream_source(),
|
||||||
|
start_time=start_time + self.__start_time,
|
||||||
|
duration=duration + self.__duration,
|
||||||
|
)
|
||||||
|
if trimmed.get_duration() < duration and strict_duration:
|
||||||
return None
|
return None
|
||||||
return VideoFromFile(self.get_stream_source(), start_time=start_time + self.__start_time, duration=duration + self.__duration)
|
return trimmed
|
||||||
|
|
||||||
|
|
||||||
class VideoFromComponents(VideoInput):
|
class VideoFromComponents(VideoInput):
|
||||||
@ -369,7 +391,7 @@ class VideoFromComponents(VideoInput):
|
|||||||
return VideoComponents(
|
return VideoComponents(
|
||||||
images=self.__components.images,
|
images=self.__components.images,
|
||||||
audio=self.__components.audio,
|
audio=self.__components.audio,
|
||||||
frame_rate=self.__components.frame_rate
|
frame_rate=self.__components.frame_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_to(
|
def save_to(
|
||||||
@ -377,7 +399,7 @@ class VideoFromComponents(VideoInput):
|
|||||||
path: str,
|
path: str,
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||||
raise ValueError("Only MP4 format is supported for now")
|
raise ValueError("Only MP4 format is supported for now")
|
||||||
@ -428,7 +450,13 @@ class VideoFromComponents(VideoInput):
|
|||||||
|
|
||||||
# Flush encoder
|
# Flush encoder
|
||||||
output.mux(audio_stream.encode(None))
|
output.mux(audio_stream.encode(None))
|
||||||
def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
|
|
||||||
|
def as_trimmed(
|
||||||
|
self,
|
||||||
|
start_time: float | None = None,
|
||||||
|
duration: float | None = None,
|
||||||
|
strict_duration: bool = True,
|
||||||
|
) -> VideoInput | None:
|
||||||
if self.get_duration() < start_time + duration:
|
if self.get_duration() < start_time + duration:
|
||||||
return None
|
return None
|
||||||
#TODO Consider tracking duration and trimming at time of save?
|
#TODO Consider tracking duration and trimming at time of save?
|
||||||
|
|||||||
@ -221,7 +221,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
"start_time",
|
"start_time",
|
||||||
default=0.0,
|
default=0.0,
|
||||||
max=1e5,
|
max=1e5,
|
||||||
min=0.0,
|
min=-1e5,
|
||||||
step=0.001,
|
step=0.001,
|
||||||
tooltip="Start time in seconds",
|
tooltip="Start time in seconds",
|
||||||
),
|
),
|
||||||
@ -230,7 +230,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
default=0.0,
|
default=0.0,
|
||||||
min=0.0,
|
min=0.0,
|
||||||
step=0.001,
|
step=0.001,
|
||||||
tooltip="Duration in seconds",
|
tooltip="Duration in seconds, or 0 for unlimited duration",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -240,7 +240,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, video, start_time, duration) -> io.NodeOutput:
|
def execute(cls, video, start_time, duration) -> io.NodeOutput:
|
||||||
trimmed = video.as_trimmed(start_time, duration)
|
trimmed = video.as_trimmed(start_time, duration, strict_duration=False)
|
||||||
if trimmed is not None:
|
if trimmed is not None:
|
||||||
return io.NodeOutput(trimmed)
|
return io.NodeOutput(trimmed)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user