feat: Add VideoSlice node with lazy operations on VideoInput
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

- Add VideoOp base class and SliceOp in _input/video_types.py
- Add sliced() method to VideoInput that returns a copy with operation appended
- Each subclass applies operations in get_components() and get_frame_count()
- After materialization, VideoFromFile delegates to internal VideoFromComponents
- Add VideoSlice node that uses video.sliced(start_frame, frame_count)
- Add tests for SliceOp, sliced() behavior, and materialization
This commit is contained in:
bymyself 2026-01-23 20:32:57 -08:00
parent d7f3241bf6
commit e4f3d335dc
6 changed files with 263 additions and 13 deletions

View File

@ -1,10 +1,12 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .video_types import VideoInput from .video_types import VideoInput, VideoOp, SliceOp
__all__ = [ __all__ = [
"ImageInput", "ImageInput",
"AudioInput", "AudioInput",
"VideoInput", "VideoInput",
"VideoOp",
"SliceOp",
"MaskInput", "MaskInput",
"LatentInput", "LatentInput",
] ]

View File

@ -1,11 +1,48 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction from fractions import Fraction
from typing import Optional, Union, IO from typing import Optional, Union, IO
import copy
import io import io
import av import av
from .._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoOp(ABC):
"""Base class for lazy video operations."""
@abstractmethod
def apply(self, components: VideoComponents) -> VideoComponents:
pass
@abstractmethod
def compute_frame_count(self, input_frame_count: int) -> int:
pass
@dataclass(frozen=True)
class SliceOp(VideoOp):
"""Extract a range of frames from the video."""
start_frame: int
frame_count: int
def apply(self, components: VideoComponents) -> VideoComponents:
total = components.images.shape[0]
start = max(0, min(self.start_frame, total))
end = min(start + self.frame_count, total)
return VideoComponents(
images=components.images[start:end],
audio=components.audio,
frame_rate=components.frame_rate,
metadata=getattr(components, 'metadata', None),
)
def compute_frame_count(self, input_frame_count: int) -> int:
start = max(0, min(self.start_frame, input_frame_count))
return min(self.frame_count, input_frame_count - start)
class VideoInput(ABC): class VideoInput(ABC):
""" """
Abstract base class for video input types. Abstract base class for video input types.
@ -21,6 +58,12 @@ class VideoInput(ABC):
""" """
pass pass
def sliced(self, start_frame: int, frame_count: int) -> "VideoInput":
"""Return a copy of this video with a slice operation appended."""
new = copy.copy(self)
new._operations = getattr(self, '_operations', []) + [SliceOp(start_frame, frame_count)]
return new
@abstractmethod @abstractmethod
def save_to( def save_to(
self, self,

View File

@ -1,7 +1,8 @@
from .video_types import VideoFromFile, VideoFromComponents from .video_types import VideoFromFile, VideoFromComponents
from .._input import SliceOp
__all__ = [ __all__ = [
# Implementations
"VideoFromFile", "VideoFromFile",
"VideoFromComponents", "VideoFromComponents",
"SliceOp",
] ]

View File

@ -3,7 +3,7 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream from av.subtitles.stream import SubtitleStream
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from .._input import AudioInput, VideoInput from .._input import AudioInput, VideoInput, VideoOp
import av import av
import io import io
import json import json
@ -63,6 +63,8 @@ class VideoFromFile(VideoInput):
containing the file contents. containing the file contents.
""" """
self.__file = file self.__file = file
self._operations: list[VideoOp] = []
self.__materialized: Optional[VideoFromComponents] = None
def get_stream_source(self) -> str | io.BytesIO: def get_stream_source(self) -> str | io.BytesIO:
""" """
@ -161,6 +163,10 @@ class VideoFromFile(VideoInput):
if frame_count == 0: if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'") raise ValueError(f"Could not determine frame count for file '{self.__file}'")
# Apply operations to get final frame count
for op in self._operations:
frame_count = op.compute_frame_count(frame_count)
return frame_count return frame_count
def get_frame_rate(self) -> Fraction: def get_frame_rate(self) -> Fraction:
@ -239,10 +245,18 @@ class VideoFromFile(VideoInput):
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if self.__materialized is not None:
return self.__materialized.get_components()
if isinstance(self.__file, io.BytesIO): if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container: with av.open(self.__file, mode='r') as container:
return self.get_components_internal(container) components = self.get_components_internal(container)
for op in self._operations:
components = op.apply(components)
self.__materialized = VideoFromComponents(components)
self._operations = []
return components
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self.__file}'")
def save_to( def save_to(
@ -317,14 +331,27 @@ class VideoFromComponents(VideoInput):
def __init__(self, components: VideoComponents): def __init__(self, components: VideoComponents):
self.__components = components self.__components = components
self._operations: list[VideoOp] = []
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if self._operations:
components = self.__components
for op in self._operations:
components = op.apply(components)
self.__components = components
self._operations = []
return VideoComponents( return VideoComponents(
images=self.__components.images, images=self.__components.images,
audio=self.__components.audio, audio=self.__components.audio,
frame_rate=self.__components.frame_rate frame_rate=self.__components.frame_rate
) )
def get_frame_count(self) -> int:
count = int(self.__components.images.shape[0])
for op in self._operations:
count = op.compute_frame_count(count)
return count
def save_to( def save_to(
self, self,
path: str, path: str,
@ -332,6 +359,9 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None
): ):
# Materialize ops before saving
components = self.get_components()
if format != VideoContainer.AUTO and format != VideoContainer.MP4: if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -345,22 +375,22 @@ class VideoFromComponents(VideoInput):
for key, value in metadata.items(): for key, value in metadata.items():
output.metadata[key] = json.dumps(value) output.metadata[key] = json.dumps(value)
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) frame_rate = Fraction(round(components.frame_rate * 1000), 1000)
# Create a video stream # Create a video stream
video_stream = output.add_stream('h264', rate=frame_rate) video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2] video_stream.width = components.images.shape[2]
video_stream.height = self.__components.images.shape[1] video_stream.height = components.images.shape[1]
video_stream.pix_fmt = 'yuv420p' video_stream.pix_fmt = 'yuv420p'
# Create an audio stream # Create an audio stream
audio_sample_rate = 1 audio_sample_rate = 1
audio_stream: Optional[av.AudioStream] = None audio_stream: Optional[av.AudioStream] = None
if self.__components.audio: if components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate']) audio_sample_rate = int(components.audio['sample_rate'])
audio_stream = output.add_stream('aac', rate=audio_sample_rate) audio_stream = output.add_stream('aac', rate=audio_sample_rate)
# Encode video # Encode video
for i, frame in enumerate(self.__components.images): for i, frame in enumerate(components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24') frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
@ -371,9 +401,9 @@ class VideoFromComponents(VideoInput):
packet = video_stream.encode(None) packet = video_stream.encode(None)
output.mux(packet) output.mux(packet)
if audio_stream and self.__components.audio: if audio_stream and components.audio:
waveform = self.__components.audio['waveform'] waveform = components.audio['waveform']
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * components.images.shape[0])]
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo') frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
frame.sample_rate = audio_sample_rate frame.sample_rate = audio_sample_rate
frame.pts = 0 frame.pts = 0

View File

@ -159,6 +159,29 @@ class GetVideoComponents(io.ComfyNode):
return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class VideoSlice(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoSlice",
display_name="Video Slice",
category="image/video",
description="Extract a range of frames from a video.",
inputs=[
io.Video.Input("video", tooltip="The video to slice."),
io.Int.Input("start_frame", default=0, min=0, tooltip="The frame index to start from (0-indexed)."),
io.Int.Input("frame_count", default=1, min=1, tooltip="Number of frames to extract."),
],
outputs=[
io.Video.Output(tooltip="The sliced video."),
],
)
@classmethod
def execute(cls, video: Input.Video, start_frame: int, frame_count: int) -> io.NodeOutput:
return io.NodeOutput(video.sliced(start_frame, frame_count))
class LoadVideo(io.ComfyNode): class LoadVideo(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -206,6 +229,7 @@ class VideoExtension(ComfyExtension):
SaveVideo, SaveVideo,
CreateVideo, CreateVideo,
GetVideoComponents, GetVideoComponents,
VideoSlice,
LoadVideo, LoadVideo,
] ]

View File

@ -0,0 +1,150 @@
import pytest
import torch
import tempfile
import os
import av
from fractions import Fraction
from comfy_api.input_impl.video_types import (
VideoFromFile,
VideoFromComponents,
SliceOp,
)
from comfy_api.util.video_types import VideoComponents
def create_test_video(width=4, height=4, frames=10, fps=30):
"""Helper to create a temporary video file."""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame_data = torch.ones(height, width, 3, dtype=torch.uint8) * (i * 25)
frame = av.VideoFrame.from_ndarray(frame_data.numpy(), format="rgb24")
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def video_file_10_frames():
file_path = create_test_video(frames=10)
yield file_path
os.unlink(file_path)
@pytest.fixture
def video_components_10_frames():
images = torch.rand(10, 4, 4, 3)
return VideoComponents(images=images, frame_rate=Fraction(30))
class TestSliceOp:
def test_apply_slices_correctly(self, video_components_10_frames):
op = SliceOp(start_frame=2, frame_count=3)
result = op.apply(video_components_10_frames)
assert result.images.shape[0] == 3
assert torch.equal(result.images, video_components_10_frames.images[2:5])
def test_compute_frame_count(self):
op = SliceOp(start_frame=2, frame_count=5)
assert op.compute_frame_count(10) == 5
def test_compute_frame_count_clamps(self):
op = SliceOp(start_frame=8, frame_count=5)
assert op.compute_frame_count(10) == 2
class TestVideoSliced:
def test_sliced_returns_new_instance(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert video is not sliced
assert len(video._operations) == 0
assert len(sliced._operations) == 1
def test_get_components_applies_operations(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[2:5])
def test_get_frame_count(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert sliced.get_frame_count() == 3
def test_get_duration(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(0, 3)
assert sliced.get_duration() == pytest.approx(0.1)
def test_chained_slices_compose(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 6).sliced(1, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert torch.equal(components.images, video_components_10_frames.images[3:6])
def test_operations_list_is_immutable(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced1 = video.sliced(0, 5)
sliced2 = sliced1.sliced(1, 2)
assert len(video._operations) == 0
assert len(sliced1._operations) == 1
assert len(sliced2._operations) == 2
def test_from_file(self, video_file_10_frames):
video = VideoFromFile(video_file_10_frames)
sliced = video.sliced(2, 3)
components = sliced.get_components()
assert components.images.shape[0] == 3
assert sliced.get_frame_count() == 3
def test_save_sliced_video(self, video_components_10_frames, tmp_path):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
output_path = str(tmp_path / "sliced_output.mp4")
sliced.save_to(output_path)
saved_video = VideoFromFile(output_path)
assert saved_video.get_frame_count() == 3
def test_materialization_clears_ops(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
assert len(sliced._operations) == 1
sliced.get_components()
assert len(sliced._operations) == 0
def test_second_get_components_uses_cache(self, video_components_10_frames):
video = VideoFromComponents(video_components_10_frames)
sliced = video.sliced(2, 3)
first = sliced.get_components()
second = sliced.get_components()
assert first.images.shape == second.images.shape
assert torch.equal(first.images, second.images)