Mirror of https://github.com/comfyanonymous/ComfyUI.git
Base TrimVideo implementation

Commit: 803808b1b1
Parent: 2129e7d278
@@ -34,6 +34,16 @@ class VideoInput(ABC):
         """
         pass
 
+    @abstractmethod
+    def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
+        """
+        Create a new VideoInput which is trimmed to have the corresponding start_time and duration
+
+        Returns:
+            A new VideoInput, or None if the result would have negative duration
+        """
+        pass
+
     def get_stream_source(self) -> Union[str, io.BytesIO]:
         """
         Get a streamable source for the video. This allows processing without
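For context, a minimal sketch of how the new as_trimmed contract is meant to be used. The file path is a placeholder and the import location is an assumption; VideoFromFile (patched below) is the concrete implementation:

    # Sketch only: "input.mp4" is a hypothetical path, and the import
    # path is assumed rather than taken from this commit.
    from comfy_api.input_impl import VideoFromFile

    video = VideoFromFile("input.mp4")
    clip = video.as_trimmed(start_time=10.0, duration=5.0)  # keep 10 s..15 s
    if clip is None:
        # start_time + duration ran past the end of the source.
        raise ValueError("Requested slice extends beyond the video")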
@@ -6,6 +6,7 @@ from typing import Optional
 from .._input import AudioInput, VideoInput
 import av
 import io
+import itertools
 import json
 import numpy as np
 import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
     formats = container_format.split(",")
     return formats[0]
 
-
 def get_open_write_kwargs(
     dest: str | io.BytesIO, container_format: str, to_format: str | None
 ) -> dict:
@@ -57,12 +57,17 @@ class VideoFromFile(VideoInput):
     Class representing video input from a file.
     """
 
-    def __init__(self, file: str | io.BytesIO):
+    def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
         """
         Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
         containing the file contents.
         """
         self.__file = file
+        self.__start_time = start_time
+        self.__duration = duration
+        if self.get_duration() < duration:
+            raise ValueError(f"Can not initialize video of negative duration:\nSource duration: {self.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}")
+        self.__duration = duration
 
     def get_stream_source(self) -> str | io.BytesIO:
         """
@@ -96,18 +101,20 @@ class VideoFromFile(VideoInput):
         Returns:
             Duration in seconds
         """
+        if self.__duration:
+            return self.__duration
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0)
         with av.open(self.__file, mode="r") as container:
             if container.duration is not None:
-                return float(container.duration / av.time_base)
+                return float(container.duration / av.time_base) - self.__start_time
 
             # Fallback: calculate from frame count and frame rate
             video_stream = next(
                 (s for s in container.streams if s.type == "video"), None
             )
             if video_stream and video_stream.frames and video_stream.average_rate:
-                return float(video_stream.frames / video_stream.average_rate)
+                return float(video_stream.frames / video_stream.average_rate) - self.start_time
 
             # Last resort: decode frames to count them
             if video_stream and video_stream.average_rate:
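For reference, container.duration is measured in av.time_base units (FFmpeg's AV_TIME_BASE, 1,000,000 ticks per second), which is why dividing by av.time_base yields seconds; subtracting the stored start offset then reports only the trimmed portion. A worked sketch with a made-up metadata value:

    import av

    duration_ticks = 12_500_000                      # hypothetical container.duration
    seconds = float(duration_ticks / av.time_base)   # 12.5
    remaining = seconds - 10.0                       # 2.5 with start_time=10.0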
@@ -117,7 +124,7 @@ class VideoFromFile(VideoInput):
                     for _ in packet.decode():
                         frame_count += 1
                 if frame_count > 0:
-                    return float(frame_count / video_stream.average_rate)
+                    return float(frame_count / video_stream.average_rate) - self.start_time
 
         raise ValueError(f"Could not determine duration for file '{self.__file}'")
 
@@ -132,12 +139,14 @@ class VideoFromFile(VideoInput):
         with av.open(self.__file, mode="r") as container:
             video_stream = self._get_first_video_stream(container)
             # 1. Prefer the frames field if available
-            if video_stream.frames and video_stream.frames > 0:
+            if video_stream.frames and video_stream.frames > 0 and not self.__start_time and not self.__duration:
                 return int(video_stream.frames)
 
+            if self.__duration:
+                return self.__duration / float(video_stream.average_rate)
             # 2. Try to estimate from duration and average_rate using only metadata
             if container.duration is not None and video_stream.average_rate:
-                duration_seconds = float(container.duration / av.time_base)
+                duration_seconds = float(container.duration / av.time_base) - self.__start_time
                 estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
                 if estimated_frames > 0:
                     return estimated_frames
@@ -147,20 +156,23 @@ class VideoFromFile(VideoInput):
                 and getattr(video_stream, "time_base", None) is not None
                 and video_stream.average_rate
             ):
-                duration_seconds = float(video_stream.duration * video_stream.time_base)
+                duration_seconds = float(video_stream.duration * video_stream.time_base) - self.start_time
                 estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
                 if estimated_frames > 0:
                     return estimated_frames
 
             # 3. Last resort: decode frames and count them (streaming)
-            frame_count = 0
-            container.seek(0)
-            for packet in container.demux(video_stream):
-                for _ in packet.decode():
-                    frame_count += 1
-            if frame_count == 0:
-                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+            frame_count = 1
+            start_pts = int(self.__start_time / video_stream.time_base)
+            container.seek(start_pts, stream=video_stream)
+            frame_iterator = container.decode(video_stream)
+            for frame in frame_iterator:
+                if frame.pts >= start_pts:
+                    break
+            else:
+                raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+            for frame in frame_iterator:
+                frame_count += 1
             return frame_count
 
     def get_frame_rate(self) -> Fraction:
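Two details worth noting in the new last-resort counter: time_base is seconds per tick, so dividing a time in seconds by it yields the stream's integer PTS, and the for/else only raises when the first loop never breaks, i.e. when no frame at or after start_pts exists. A worked conversion with a made-up time base:

    from fractions import Fraction

    time_base = Fraction(1, 15360)            # hypothetical mp4 video time base
    start_time = 10.0                         # seconds
    start_pts = int(start_time / time_base)   # 153600 ticks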
@@ -199,9 +211,17 @@ class VideoFromFile(VideoInput):
             return container.format.name
 
     def get_components_internal(self, container: InputContainer) -> VideoComponents:
+        video_stream = container.streams.video[0]
         # Get video frames
         frames = []
-        for frame in container.decode(video=0):
+        start_pts = int(self.__start_time / video_stream.time_base)
+        end_pts = int((self.__start_time + self.__duration) / video_stream.time_base)
+        container.seek(start_pts, stream=video_stream)
+        for frame in container.decode(video_stream):
+            if frame.pts < start_pts:
+                continue
+            if self.__duration and frame.pts >= end_pts:
+                break
             img = frame.to_ndarray(format='rgb24')  # shape: (H, W, 3)
             img = torch.from_numpy(img) / 255.0  # shape: (H, W, 3)
             frames.append(img)
@@ -209,31 +229,43 @@ class VideoFromFile(VideoInput):
         images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
 
         # Get frame rate
-        video_stream = next(s for s in container.streams if s.type == 'video')
-        frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+        frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
 
         # Get audio if available
         audio = None
-        try:
-            container.seek(0)  # Reset the container to the beginning
-            for stream in container.streams:
-                if stream.type != 'audio':
-                    continue
-                assert isinstance(stream, av.AudioStream)
-                audio_frames = []
-                for packet in container.demux(stream):
-                    for frame in packet.decode():
-                        assert isinstance(frame, av.AudioFrame)
-                        audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
-                if len(audio_frames) > 0:
-                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
-                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
-                    audio = AudioInput({
-                        "waveform": audio_tensor,
-                        "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
-                    })
-        except StopIteration:
-            pass  # No audio stream
+        container.seek(start_pts, stream=video_stream)
+        # Use last stream for consistency
+        audio_stream = container.streams.audio[-1]
+        if audio_stream:
+            audio_frames = []
+            resample = av.audio.resampler.AudioResampler(format='fltp').resample
+            frames = itertools.chain.from_iterable(
+                map(resample, container.decode(audio_stream)))
+
+            has_first_frame = False
+            for frame in frames:
+                offset_seconds = self.__start_time - frame.pts * audio_stream.time_base
+                to_skip = int(offset_seconds * audio_stream.sample_rate)
+                if to_skip < frame.samples:
+                    has_first_frame = True
+                    break
+            if has_first_frame:
+                audio_frames.append(frame.to_ndarray()[..., to_skip:])
+
+            for frame in frames:
+                if frame.time > self.__start_time + self.__duration:
+                    break
+                audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
+            if len(audio_frames) > 0:
+                audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
+                if self.__duration:
+                    audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+
+                audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
+                audio = AudioInput({
+                    "waveform": audio_tensor,
+                    "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+                })
 
         metadata = container.metadata
         return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
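Decoding resumes from the seek point, which generally lands on a keyframe slightly before start_time, so the first overlapping audio frame must have its leading samples dropped. The skip count is just the time gap scaled by the sample rate; a worked sketch with made-up numbers:

    sample_rate = 48000
    frame_start = 9.98                            # frame.pts * time_base, in seconds
    start_time = 10.0                             # requested trim point
    offset_seconds = start_time - frame_start
    to_skip = int(offset_seconds * sample_rate)   # 960 samples precede the trim point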
@@ -262,6 +294,8 @@ class VideoFromFile(VideoInput):
                 reuse_streams = False
             if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
                 reuse_streams = False
+            if self.__start_time or self.__duration:
+                reuse_streams = False
 
             if not reuse_streams:
                 components = self.get_components_internal(container)
@@ -309,6 +343,11 @@ class VideoFromFile(VideoInput):
             raise ValueError(f"No video stream found in file '{self.__file}'")
         return video_stream
 
+    def as_trimmed(self, start_time: float=0, duration: float=0) -> VideoInput|None:
+        if self.get_duration() < start_time + duration:
+            return None
+        return VideoFromFile(self.get_stream_source(), start_time=start_time + self.__start_time, duration=duration + self.__duration)
+
 
 class VideoFromComponents(VideoInput):
     """
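The None return mirrors the abstract contract: a slice that would run past the end of the source is rejected rather than producing a negative duration. Note also that start_time accumulates the offset of an already-trimmed source, so nested trims keep decoding from the original file. A sketch assuming a hypothetical 3-second clip:

    video = VideoFromFile("short.mp4")                     # hypothetical 3 s source
    clip = video.as_trimmed(start_time=2.0, duration=5.0)
    assert clip is None                                    # 2.0 + 5.0 > 3.0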
@@ -381,3 +420,8 @@ class VideoFromComponents(VideoInput):
 
             # Flush encoder
             output.mux(audio_stream.encode(None))
+    def as_trimmed(self, start_time: float|None=None, duration: float|None=None) -> VideoInput|None:
+        if self.get_duration() < start_time + duration:
+            return None
+        #TODO Consider tracking duration and trimming at time of save?
+        return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
@@ -202,6 +202,28 @@ class LoadVideo(io.ComfyNode):
 
         return True
 
+class VideoSlice(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Video Slice",
+            display_name="Video Slice",
+            search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"],
+            category="image/video",
+            inputs=[
+                io.Video.Input('video'),
+                io.Float.Input('start_time', default=0.0, min=0.0, step=.001),
+                io.Float.Input('duration', default=0.0, min=0.0, step=.001),
+            ],
+            outputs=[
+                io.Video.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, video, start_time, duration) -> io.NodeOutput:
+        return io.NodeOutput(video.as_trimmed(start_time, duration))
+
 
 class VideoExtension(ComfyExtension):
     @override
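Because every consumer of the stored duration tests its truthiness, the node's duration=0.0 default effectively means "keep everything from start_time to the end of the source". A hypothetical direct invocation (in practice the node runs inside a ComfyUI graph, with video arriving from LoadVideo):

    # Sketch only; calling execute directly bypasses the graph machinery.
    out = VideoSlice.execute(video, start_time=1.0, duration=0.0)  # 1 s to end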
@@ -212,6 +234,7 @@ class VideoExtension(ComfyExtension):
             CreateVideo,
             GetVideoComponents,
             LoadVideo,
+            VideoSlice,
         ]
 
 async def comfy_entrypoint() -> VideoExtension: