From c7d29a42ba2b6ff3161db850c89cf6771f0f90d3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:19:06 +0800 Subject: [PATCH] refactor(video dataset): lazy video loading and frame-selective decode - Replace eager load_video_frames() with _decode_selected_frames() that opens the container with `with av.open(...)` (no resource leak) and decodes only the requested frame indices. - Video loader nodes now emit lazy VideoFromFile references; sampling and temporal-crop nodes operate lazily / decode only selected frames. --- comfy_extras/nodes_dataset.py | 227 +++++++++++++--------------------- 1 file changed, 85 insertions(+), 142 deletions(-) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 089d37dbc..d70a14ccc 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -10,7 +10,7 @@ from typing_extensions import override import folder_paths import node_helpers -from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import ComfyExtension, io, Input, InputImpl, Types def load_and_process_images(image_files, input_dir): @@ -46,44 +46,33 @@ def load_and_process_images(image_files, input_dir): VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"] -def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0): - """Load video file and return frames as a tensor. +def _decode_selected_frames(video: Input.Video, indices: list[int]) -> Input.Video: + """Decode only the requested frame indices from a video. - Args: - video_path: Path to the video file - max_frames: Maximum number of frames to load (0 = all) - frame_stride: Sample every Nth frame - start_frame: Frame index to start from - - Returns: - torch.Tensor: Video frames as [T, H, W, C] float32 tensor in [0, 1] + Opens the underlying container once, decodes frames in presentation order, + keeps only the ones whose index is in ``indices``, and returns the result + wrapped in a VideoFromComponents so it still satisfies the VideoInput + contract for downstream nodes. """ - container = av.open(video_path) - stream = container.streams.video[0] + indices_sorted = sorted(set(indices)) + max_idx = indices_sorted[-1] + source = video.get_stream_source() - frames = [] - frame_idx = 0 - for frame in container.decode(stream): - if frame_idx < start_frame: - frame_idx += 1 - continue - if (frame_idx - start_frame) % frame_stride != 0: - frame_idx += 1 - continue - if max_frames > 0 and len(frames) >= max_frames: - break + frames_by_idx: dict[int, torch.Tensor] = {} + with av.open(source, mode="r") as container: + stream = container.streams.video[0] + wanted = set(indices_sorted) + for frame_idx, frame in enumerate(container.decode(stream)): + if frame_idx in wanted: + img = frame.to_ndarray(format="rgb24") + frames_by_idx[frame_idx] = torch.from_numpy(img.copy()).float() / 255.0 + if frame_idx >= max_idx: + break - img = frame.to_ndarray(format='rgb24') - img_tensor = torch.from_numpy(img.copy()).float() / 255.0 - frames.append(img_tensor) - frame_idx += 1 - - container.close() - - if not frames: - raise ValueError(f"No frames could be extracted from {video_path}") - - return torch.stack(frames) # [T, H, W, C] + stacked = torch.stack([frames_by_idx[i] for i in indices]) + return InputImpl.VideoFromComponents( + Types.VideoComponents(images=stacked, frame_rate=video.get_frame_rate()) + ) class LoadImageDataSetFromFolderNode(io.ComfyNode): @@ -211,39 +200,18 @@ class LoadVideoDataSetFromFolderNode(io.ComfyNode): options=folder_paths.get_input_subfolders(), tooltip="The folder containing video files.", ), - io.Int.Input( - "max_frames", - default=0, - min=0, - max=99999, - tooltip="Maximum frames to load per video (0 = all frames).", - ), - io.Int.Input( - "frame_stride", - default=1, - min=1, - max=1000, - tooltip="Sample every Nth frame (1 = every frame).", - ), - io.Int.Input( - "start_frame", - default=0, - min=0, - max=99999, - tooltip="Frame index to start loading from.", - ), ], outputs=[ - io.Image.Output( + io.Video.Output( display_name="videos", is_output_list=True, - tooltip="List of video tensors, each [T, H, W, C].", + tooltip="Lazy video references; frames are decoded only when needed downstream.", ), ], ) @classmethod - def execute(cls, folder, max_frames, frame_stride, start_frame): + def execute(cls, folder): sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) video_files = sorted([ f for f in os.listdir(sub_input_dir) @@ -253,15 +221,9 @@ class LoadVideoDataSetFromFolderNode(io.ComfyNode): if not video_files: raise ValueError(f"No video files found in {sub_input_dir}") - output_videos = [] - for file in video_files: - video_path = os.path.join(sub_input_dir, file) - frames = load_video_frames(video_path, max_frames, frame_stride, start_frame) - output_videos.append(frames) - logging.info(f"Loaded {file}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}") - - logging.info(f"Loaded {len(output_videos)} videos from {sub_input_dir}") - return io.NodeOutput(output_videos) + videos = [InputImpl.VideoFromFile(os.path.join(sub_input_dir, f)) for f in video_files] + logging.info(f"Loaded {len(videos)} lazy video references from {sub_input_dir}") + return io.NodeOutput(videos) class LoadVideoTextDataSetFromFolderNode(io.ComfyNode): @@ -278,33 +240,12 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode): options=folder_paths.get_input_subfolders(), tooltip="The folder containing video files and .txt captions.", ), - io.Int.Input( - "max_frames", - default=0, - min=0, - max=99999, - tooltip="Maximum frames to load per video (0 = all frames).", - ), - io.Int.Input( - "frame_stride", - default=1, - min=1, - max=1000, - tooltip="Sample every Nth frame (1 = every frame).", - ), - io.Int.Input( - "start_frame", - default=0, - min=0, - max=99999, - tooltip="Frame index to start loading from.", - ), ], outputs=[ - io.Image.Output( + io.Video.Output( display_name="videos", is_output_list=True, - tooltip="List of video tensors, each [T, H, W, C].", + tooltip="Lazy video references; frames are decoded only when needed downstream.", ), io.String.Output( display_name="texts", @@ -315,7 +256,7 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode): ) @classmethod - def execute(cls, folder, max_frames, frame_stride, start_frame): + def execute(cls, folder): sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) video_files = [] @@ -337,7 +278,6 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode): if not video_files: raise ValueError(f"No video files found in {sub_input_dir}") - # Load captions (same name as video but .txt) captions = [] for vf in video_files: caption_path = os.path.splitext(vf)[0] + ".txt" @@ -347,14 +287,9 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode): else: captions.append("") - # Load videos - output_videos = [] - for vf in video_files: - frames = load_video_frames(vf, max_frames, frame_stride, start_frame) - output_videos.append(frames) - - logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_input_dir}") - return io.NodeOutput(output_videos, captions) + videos = [InputImpl.VideoFromFile(vf) for vf in video_files] + logging.info(f"Loaded {len(videos)} lazy video references with captions from {sub_input_dir}") + return io.NodeOutput(videos, captions) def save_images_to_folder(image_list, output_dir, prefix="image"): @@ -1088,7 +1023,12 @@ class ShuffleImageTextDatasetNode(io.ComfyNode): class VideoFrameSampleNode(io.ComfyNode): - """Sample a fixed number of frames from a video using various strategies.""" + """Sample a fixed number of frames from a video using various strategies. + + For contiguous strategies ("head"/"tail") the result is a fully lazy + VideoInput (no frames decoded). For non-contiguous strategies + ("uniform"/"random") only the selected indices are decoded. + """ @classmethod def define_schema(cls): @@ -1098,7 +1038,7 @@ class VideoFrameSampleNode(io.ComfyNode): category="dataset/video", is_experimental=True, inputs=[ - io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."), + io.Video.Input("video", tooltip="Input video."), io.Int.Input( "num_frames", default=16, @@ -1121,20 +1061,27 @@ class VideoFrameSampleNode(io.ComfyNode): ), ], outputs=[ - io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."), + io.Video.Output(display_name="video", tooltip="Sampled video."), ], ) @classmethod def execute(cls, video, num_frames, strategy, seed): - total_frames = video.shape[0] + total_frames = video.get_frame_count() num_frames = min(num_frames, total_frames) + fps = float(video.get_frame_rate()) if strategy == "head": - indices = list(range(num_frames)) - elif strategy == "tail": - indices = list(range(total_frames - num_frames, total_frames)) - elif strategy == "uniform": + return io.NodeOutput( + video.as_trimmed(0.0, num_frames / fps, strict_duration=False) + ) + if strategy == "tail": + start_t = (total_frames - num_frames) / fps + return io.NodeOutput( + video.as_trimmed(start_t, num_frames / fps, strict_duration=False) + ) + + if strategy == "uniform": if num_frames == 1: indices = [total_frames // 2] else: @@ -1145,11 +1092,11 @@ class VideoFrameSampleNode(io.ComfyNode): else: raise ValueError(f"Unknown strategy: {strategy}") - return io.NodeOutput(video[indices]) + return io.NodeOutput(_decode_selected_frames(video, indices)) class VideoTemporalCropNode(io.ComfyNode): - """Crop a continuous range of frames from a video.""" + """Crop a continuous range of frames from a video (fully lazy).""" @classmethod def define_schema(cls): @@ -1159,7 +1106,7 @@ class VideoTemporalCropNode(io.ComfyNode): category="dataset/video", is_experimental=True, inputs=[ - io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."), + io.Video.Input("video", tooltip="Input video."), io.Int.Input( "start_frame", default=0, @@ -1176,20 +1123,23 @@ class VideoTemporalCropNode(io.ComfyNode): ), ], outputs=[ - io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."), + io.Video.Output(display_name="video", tooltip="Cropped video (lazy)."), ], ) @classmethod def execute(cls, video, start_frame, length): - total_frames = video.shape[0] + total_frames = video.get_frame_count() + fps = float(video.get_frame_rate()) start_frame = min(start_frame, max(total_frames - 1, 0)) - end_frame = min(start_frame + length, total_frames) - return io.NodeOutput(video[start_frame:end_frame]) + length = min(length, total_frames - start_frame) + return io.NodeOutput( + video.as_trimmed(start_frame / fps, length / fps, strict_duration=False) + ) class VideoRandomTemporalCropNode(io.ComfyNode): - """Randomly crop a continuous range of frames from a video (for data augmentation).""" + """Randomly crop a continuous range of frames from a video (fully lazy).""" @classmethod def define_schema(cls): @@ -1199,7 +1149,7 @@ class VideoRandomTemporalCropNode(io.ComfyNode): category="dataset/video", is_experimental=True, inputs=[ - io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."), + io.Video.Input("video", tooltip="Input video."), io.Int.Input( "length", default=16, @@ -1216,49 +1166,42 @@ class VideoRandomTemporalCropNode(io.ComfyNode): ), ], outputs=[ - io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."), + io.Video.Output(display_name="video", tooltip="Cropped video (lazy)."), ], ) @classmethod def execute(cls, video, length, seed): - total_frames = video.shape[0] + total_frames = video.get_frame_count() + fps = float(video.get_frame_rate()) length = min(length, total_frames) max_start = total_frames - length rng = np.random.RandomState(seed % (2**32 - 1)) start = rng.randint(0, max_start + 1) if max_start > 0 else 0 - return io.NodeOutput(video[start:start + length]) + return io.NodeOutput( + video.as_trimmed(start / fps, length / fps, strict_duration=False) + ) -class ShuffleVideoDatasetNode(ImageProcessingNode): +class ShuffleVideoDatasetNode(io.ComfyNode): """Randomly shuffle the order of videos in the dataset.""" - node_id = "ShuffleVideoDataset" - display_name = "Shuffle Video Dataset" - description = "Randomly shuffle the order of videos in the dataset." - is_group_process = True - extra_inputs = [ - io.Int.Input( - "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." - ), - ] - @classmethod def define_schema(cls): return io.Schema( - node_id=cls.node_id, - display_name=cls.display_name, + node_id="ShuffleVideoDataset", + display_name="Shuffle Video Dataset", category="dataset/video", is_experimental=True, is_input_list=True, inputs=[ - io.Image.Input("images", tooltip="List of videos to shuffle."), + io.Video.Input("videos", tooltip="List of videos to shuffle."), io.Int.Input( "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." ), ], outputs=[ - io.Image.Output( + io.Video.Output( display_name="videos", is_output_list=True, tooltip="Shuffled videos", @@ -1267,11 +1210,11 @@ class ShuffleVideoDatasetNode(ImageProcessingNode): ) @classmethod - def execute(cls, images, seed): + def execute(cls, videos, seed): seed = seed[0] if isinstance(seed, list) else seed np.random.seed(seed % (2**32 - 1)) - indices = np.random.permutation(len(images)) - return io.NodeOutput([images[i] for i in indices]) + indices = np.random.permutation(len(videos)) + return io.NodeOutput([videos[i] for i in indices]) class ShuffleVideoTextDatasetNode(io.ComfyNode): @@ -1286,7 +1229,7 @@ class ShuffleVideoTextDatasetNode(io.ComfyNode): is_experimental=True, is_input_list=True, inputs=[ - io.Image.Input("videos", tooltip="List of videos to shuffle."), + io.Video.Input("videos", tooltip="List of videos to shuffle."), io.String.Input("texts", tooltip="List of texts to shuffle."), io.Int.Input( "seed", @@ -1297,7 +1240,7 @@ class ShuffleVideoTextDatasetNode(io.ComfyNode): ), ], outputs=[ - io.Image.Output( + io.Video.Output( display_name="videos", is_output_list=True, tooltip="Shuffled videos", @@ -1976,7 +1919,7 @@ class DatasetExtension(ComfyExtension): AdjustContrastNode, ShuffleDatasetNode, ShuffleImageTextDatasetNode, - # Video processing nodes + # Video processing nodes (lazy VideoInput in/out) VideoFrameSampleNode, VideoTemporalCropNode, VideoRandomTemporalCropNode,