From d364d3f8b5cc2641214d1c2984b0b97b12708e82 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:39:56 +0800 Subject: [PATCH] Init implementation of video dataset nodes --- comfy_extras/nodes_dataset.py | 476 +++++++++++++++++++++++++++++++++- 1 file changed, 473 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 98ed25d7e..089d37dbc 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -2,6 +2,7 @@ import logging import os import json +import av import numpy as np import torch from PIL import Image @@ -42,6 +43,49 @@ def load_and_process_images(image_files, input_dir): return output_images +VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"] + + +def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0): + """Load video file and return frames as a tensor. + + Args: + video_path: Path to the video file + max_frames: Maximum number of frames to load (0 = all) + frame_stride: Sample every Nth frame + start_frame: Frame index to start from + + Returns: + torch.Tensor: Video frames as [T, H, W, C] float32 tensor in [0, 1] + """ + container = av.open(video_path) + stream = container.streams.video[0] + + frames = [] + frame_idx = 0 + for frame in container.decode(stream): + if frame_idx < start_frame: + frame_idx += 1 + continue + if (frame_idx - start_frame) % frame_stride != 0: + frame_idx += 1 + continue + if max_frames > 0 and len(frames) >= max_frames: + break + + img = frame.to_ndarray(format='rgb24') + img_tensor = torch.from_numpy(img.copy()).float() / 255.0 + frames.append(img_tensor) + frame_idx += 1 + + container.close() + + if not frames: + raise ValueError(f"No frames could be extracted from {video_path}") + + return torch.stack(frames) # [T, H, W, C] + + class LoadImageDataSetFromFolderNode(io.ComfyNode): @classmethod def define_schema(cls): @@ -153,6 +197,166 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode): return io.NodeOutput(output_tensor, captions) +class LoadVideoDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadVideoDataSetFromFolder", + display_name="Load Video Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder containing video files.", + ), + io.Int.Input( + "max_frames", + default=0, + min=0, + max=99999, + tooltip="Maximum frames to load per video (0 = all frames).", + ), + io.Int.Input( + "frame_stride", + default=1, + min=1, + max=1000, + tooltip="Sample every Nth frame (1 = every frame).", + ), + io.Int.Input( + "start_frame", + default=0, + min=0, + max=99999, + tooltip="Frame index to start loading from.", + ), + ], + outputs=[ + io.Image.Output( + display_name="videos", + is_output_list=True, + tooltip="List of video tensors, each [T, H, W, C].", + ), + ], + ) + + @classmethod + def execute(cls, folder, max_frames, frame_stride, start_frame): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + video_files = sorted([ + f for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS) + ]) + + if not video_files: + raise ValueError(f"No video files found in {sub_input_dir}") + + output_videos = [] + for file in video_files: + video_path = os.path.join(sub_input_dir, file) + frames = 
load_video_frames(video_path, max_frames, frame_stride, start_frame) + output_videos.append(frames) + logging.info(f"Loaded {file}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}") + + logging.info(f"Loaded {len(output_videos)} videos from {sub_input_dir}") + return io.NodeOutput(output_videos) + + +class LoadVideoTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadVideoTextDataSetFromFolder", + display_name="Load Video and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder containing video files and .txt captions.", + ), + io.Int.Input( + "max_frames", + default=0, + min=0, + max=99999, + tooltip="Maximum frames to load per video (0 = all frames).", + ), + io.Int.Input( + "frame_stride", + default=1, + min=1, + max=1000, + tooltip="Sample every Nth frame (1 = every frame).", + ), + io.Int.Input( + "start_frame", + default=0, + min=0, + max=99999, + tooltip="Frame index to start loading from.", + ), + ], + outputs=[ + io.Image.Output( + display_name="videos", + is_output_list=True, + tooltip="List of video tensors, each [T, H, W, C].", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions.", + ), + ], + ) + + @classmethod + def execute(cls, folder, max_frames, frame_stride, start_frame): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + + video_files = [] + for item in sorted(os.listdir(sub_input_dir)): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS): + video_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure: {repeat}_{desc}/ + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + video_files.extend([ + os.path.join(path, f) + for f in sorted(os.listdir(path)) + if any(f.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS) + ] * repeat) + + if not video_files: + raise ValueError(f"No video files found in {sub_input_dir}") + + # Load captions (same name as video but .txt) + captions = [] + for vf in video_files: + caption_path = os.path.splitext(vf)[0] + ".txt" + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + captions.append(f.read().strip()) + else: + captions.append("") + + # Load videos + output_videos = [] + for vf in video_files: + frames = load_video_frames(vf, max_frames, frame_stride, start_frame) + output_videos.append(frames) + + logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_input_dir}") + return io.NodeOutput(output_videos, captions) + + def save_images_to_folder(image_list, output_dir, prefix="image"): """Utility function to save a list of image tensors to disk. @@ -418,7 +622,15 @@ class ImageProcessingNode(io.ComfyNode): @classmethod def execute(cls, images, **kwargs): - """Execute the node. Routes to _process or _group_process based on mode.""" + """Execute the node. Routes to _process or _group_process based on mode. + + For individual processing (_process), automatically handles multi-frame + inputs (video tensors [T, H, W, C]) by applying _process per-frame and + concatenating the results. This allows all spatial transform nodes to + work with video without modification. Nodes that natively handle batched + tensors (e.g. 
pure tensor math) can set per_frame_process = False to + skip the per-frame loop. + """ is_group = cls._detect_processing_mode() # Extract scalar values from lists for parameters @@ -434,7 +646,16 @@ class ImageProcessingNode(io.ComfyNode): result = cls._group_process(images, **params) else: # Individual processing: images is single item, call _process - result = cls._process(images, **params) + # Auto-loop over frames for multi-frame inputs (video [T, H, W, C]) + # so that PIL-based spatial transforms work per-frame automatically. + if images.shape[0] > 1 and getattr(cls, 'per_frame_process', True): + results = [] + for i in range(images.shape[0]): + frame_result = cls._process(images[i:i + 1], **params) + results.append(frame_result) + result = torch.cat(results, dim=0) + else: + result = cls._process(images, **params) return io.NodeOutput(result) @@ -736,6 +957,7 @@ class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" display_name = "Normalize Images" description = "Normalize images using mean and standard deviation." + per_frame_process = False # Pure tensor math, handles any batch size extra_inputs = [ io.Float.Input( "mean", @@ -764,6 +986,7 @@ class AdjustBrightnessNode(ImageProcessingNode): node_id = "AdjustBrightness" display_name = "Adjust Brightness" description = "Adjust brightness of all images." + per_frame_process = False # Pure tensor math, handles any batch size extra_inputs = [ io.Float.Input( "factor", @@ -783,6 +1006,7 @@ class AdjustContrastNode(ImageProcessingNode): node_id = "AdjustContrast" display_name = "Adjust Contrast" description = "Adjust contrast of all images." + per_frame_process = False # Pure tensor math, handles any batch size extra_inputs = [ io.Float.Input( "factor", @@ -860,6 +1084,243 @@ class ShuffleImageTextDatasetNode(io.ComfyNode): return io.NodeOutput(shuffled_images, shuffled_texts) +# ========== Video Processing Nodes ========== + + +class VideoFrameSampleNode(io.ComfyNode): + """Sample a fixed number of frames from a video using various strategies.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VideoFrameSample", + display_name="Video Frame Sample", + category="dataset/video", + is_experimental=True, + inputs=[ + io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."), + io.Int.Input( + "num_frames", + default=16, + min=1, + max=9999, + tooltip="Number of frames to sample.", + ), + io.Combo.Input( + "strategy", + options=["uniform", "head", "tail", "random"], + default="uniform", + tooltip="uniform: evenly spaced, head: first N, tail: last N, random: random sorted.", + ), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed (only used with 'random' strategy).", + ), + ], + outputs=[ + io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."), + ], + ) + + @classmethod + def execute(cls, video, num_frames, strategy, seed): + total_frames = video.shape[0] + num_frames = min(num_frames, total_frames) + + if strategy == "head": + indices = list(range(num_frames)) + elif strategy == "tail": + indices = list(range(total_frames - num_frames, total_frames)) + elif strategy == "uniform": + if num_frames == 1: + indices = [total_frames // 2] + else: + indices = [round(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)] + elif strategy == "random": + rng = np.random.RandomState(seed % (2**32 - 1)) + indices = sorted(rng.choice(total_frames, size=num_frames, replace=False).tolist()) + else: + raise 
ValueError(f"Unknown strategy: {strategy}") + + return io.NodeOutput(video[indices]) + + +class VideoTemporalCropNode(io.ComfyNode): + """Crop a continuous range of frames from a video.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VideoTemporalCrop", + display_name="Video Temporal Crop", + category="dataset/video", + is_experimental=True, + inputs=[ + io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."), + io.Int.Input( + "start_frame", + default=0, + min=0, + max=99999, + tooltip="Starting frame index.", + ), + io.Int.Input( + "length", + default=16, + min=1, + max=99999, + tooltip="Number of frames to keep.", + ), + ], + outputs=[ + io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."), + ], + ) + + @classmethod + def execute(cls, video, start_frame, length): + total_frames = video.shape[0] + start_frame = min(start_frame, max(total_frames - 1, 0)) + end_frame = min(start_frame + length, total_frames) + return io.NodeOutput(video[start_frame:end_frame]) + + +class VideoRandomTemporalCropNode(io.ComfyNode): + """Randomly crop a continuous range of frames from a video (for data augmentation).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VideoRandomTemporalCrop", + display_name="Video Random Temporal Crop", + category="dataset/video", + is_experimental=True, + inputs=[ + io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."), + io.Int.Input( + "length", + default=16, + min=1, + max=99999, + tooltip="Number of frames to keep.", + ), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."), + ], + ) + + @classmethod + def execute(cls, video, length, seed): + total_frames = video.shape[0] + length = min(length, total_frames) + max_start = total_frames - length + rng = np.random.RandomState(seed % (2**32 - 1)) + start = rng.randint(0, max_start + 1) if max_start > 0 else 0 + return io.NodeOutput(video[start:start + length]) + + +class ShuffleVideoDatasetNode(ImageProcessingNode): + """Randomly shuffle the order of videos in the dataset.""" + + node_id = "ShuffleVideoDataset" + display_name = "Shuffle Video Dataset" + description = "Randomly shuffle the order of videos in the dataset." + is_group_process = True + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name, + category="dataset/video", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of videos to shuffle."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." 
+ ), + ], + outputs=[ + io.Image.Output( + display_name="videos", + is_output_list=True, + tooltip="Shuffled videos", + ), + ], + ) + + @classmethod + def execute(cls, images, seed): + seed = seed[0] if isinstance(seed, list) else seed + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return io.NodeOutput([images[i] for i in indices]) + + +class ShuffleVideoTextDatasetNode(io.ComfyNode): + """Shuffle videos and their captions together, preserving pairs.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleVideoTextDataset", + display_name="Shuffle Video-Text Dataset", + category="dataset/video", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("videos", tooltip="List of videos to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="videos", + is_output_list=True, + tooltip="Shuffled videos", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="Shuffled texts", + ), + ], + ) + + @classmethod + def execute(cls, videos, texts, seed): + seed = seed[0] if isinstance(seed, list) else seed + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(videos)) + return io.NodeOutput( + [videos[i] for i in indices], + [texts[i] for i in indices], + ) + + # ========== Text Transform Nodes ========== @@ -1502,7 +1963,10 @@ class DatasetExtension(ComfyExtension): LoadImageTextDataSetFromFolderNode, SaveImageDataSetToFolderNode, SaveImageTextDataSetToFolderNode, - # Image transform nodes + # Video data loading nodes + LoadVideoDataSetFromFolderNode, + LoadVideoTextDataSetFromFolderNode, + # Image transform nodes (auto-handle video via per-frame processing) ResizeImagesByShorterEdgeNode, ResizeImagesByLongerEdgeNode, CenterCropImagesNode, @@ -1512,6 +1976,12 @@ class DatasetExtension(ComfyExtension): AdjustContrastNode, ShuffleDatasetNode, ShuffleImageTextDatasetNode, + # Video processing nodes + VideoFrameSampleNode, + VideoTemporalCropNode, + VideoRandomTemporalCropNode, + ShuffleVideoDatasetNode, + ShuffleVideoTextDatasetNode, # Text transform nodes TextToLowercaseNode, TextToUppercaseNode,
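
The decode loop in load_video_frames keeps frame index i iff i >= start_frame and (i - start_frame) % frame_stride == 0, stopping once max_frames frames have been collected. A minimal sketch of that selection rule as pure index math, with no PyAV dependency (selected_indices is a hypothetical helper, not part of the patch):

def selected_indices(total_frames, max_frames=0, frame_stride=1, start_frame=0):
    # Mirrors the skip/stride/cap logic of the decode loop in load_video_frames.
    kept = []
    for i in range(start_frame, total_frames, frame_stride):
        if max_frames > 0 and len(kept) >= max_frames:
            break
        kept.append(i)
    return kept

assert selected_indices(10) == list(range(10))  # max_frames=0 keeps every frame
assert selected_indices(10, max_frames=3, frame_stride=2, start_frame=1) == [1, 3, 5]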
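
The "uniform" strategy in VideoFrameSampleNode always includes the first and last frame and rounds the remaining indices to the nearest integer. A sketch of the index math, assuming num_frames has already been clamped to the total frame count as the node does (uniform_indices is a hypothetical helper):

def uniform_indices(total_frames, num_frames):
    # Single frame: take the middle one, matching the node's special case.
    if num_frames == 1:
        return [total_frames // 2]
    return [round(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]

assert uniform_indices(100, 5) == [0, 25, 50, 74, 99]
assert uniform_indices(16, 16) == list(range(16))  # degenerates to the identity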
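
LoadVideoTextDataSetFromFolderNode expands kohya-ss/sd-scripts style "{repeat}_{description}" subfolders by listing each contained file repeat times. A simplified sketch of that expansion (expand_kohya_dir is a hypothetical helper; the node additionally filters by VALID_VIDEO_EXTENSIONS):

import os

def expand_kohya_dir(dirname, files):
    # A leading integer prefix ("3_style") is the repeat count; default is 1.
    prefix = dirname.split("_")[0]
    repeat = int(prefix) if prefix.isdigit() else 1
    return [os.path.join(dirname, f) for f in sorted(files)] * repeat

assert expand_kohya_dir("3_style", ["b.mp4", "a.mp4"]) == [
    os.path.join("3_style", "a.mp4"),
    os.path.join("3_style", "b.mp4"),
] * 3
assert expand_kohya_dir("style", ["a.mp4"]) == [os.path.join("style", "a.mp4")]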
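
The per-frame fallback added to ImageProcessingNode.execute slices a [T, H, W, C] video into [1, H, W, C] frames, applies _process to each, and concatenates the results, which is what lets the existing single-image spatial transforms run on video unchanged. A sketch of that loop and of why it is a no-op for batch-safe tensor math (per_frame and the flip stand-in are illustrative, not part of the patch):

import torch

def per_frame(process, video):
    # video: [T, H, W, C]; apply process frame by frame and re-concatenate.
    return torch.cat([process(video[i:i + 1]) for i in range(video.shape[0])], dim=0)

flip_w = lambda x: torch.flip(x, dims=[2])  # stand-in for a spatial _process
video = torch.rand(4, 8, 8, 3)
# Batch-safe ops give identical results either way, which is why such nodes
# can set per_frame_process = False and skip the loop entirely.
assert torch.equal(per_frame(flip_w, video), flip_w(video))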
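
ShuffleVideoTextDatasetNode keeps video/caption pairs aligned by drawing a single permutation and indexing both lists with it. A sketch of the pairing invariant (the node seeds the global np.random state; the sketch uses a local RandomState only for isolation):

import numpy as np

videos = ["v0", "v1", "v2"]
texts = ["t0", "t1", "t2"]
idx = np.random.RandomState(0).permutation(len(videos))
shuffled_videos = [videos[i] for i in idx]
shuffled_texts = [texts[i] for i in idx]
# Whatever the permutation, element k of both lists refers to the same pair.
assert all(v[1:] == t[1:] for v, t in zip(shuffled_videos, shuffled_texts))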