Initial implementation of video dataset nodes

This commit is contained in:
Kohaku-Blueleaf 2026-04-14 14:39:56 +08:00
parent b615af1c65
commit d364d3f8b5

View File

@ -2,6 +2,7 @@ import logging
import os import os
import json import json
import av
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
@ -42,6 +43,49 @@ def load_and_process_images(image_files, input_dir):
return output_images return output_images
VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]
def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0):
    """Load a video file and return its frames as a tensor.

    Args:
        video_path: Path to the video file.
        max_frames: Maximum number of frames to load (0 = all).
        frame_stride: Sample every Nth frame.
        start_frame: Frame index to start from.

    Returns:
        torch.Tensor: Video frames as a [T, H, W, C] float32 tensor in [0, 1].

    Raises:
        ValueError: If no frames could be extracted from the file.
    """
    container = av.open(video_path)
    try:
        stream = container.streams.video[0]
        frames = []
        frame_idx = 0
        for frame in container.decode(stream):
            # Skip frames before the requested start point.
            if frame_idx < start_frame:
                frame_idx += 1
                continue
            # Keep only every `frame_stride`-th frame after the start.
            if (frame_idx - start_frame) % frame_stride != 0:
                frame_idx += 1
                continue
            if max_frames > 0 and len(frames) >= max_frames:
                break
            img = frame.to_ndarray(format='rgb24')
            # .copy() detaches from PyAV's internal buffer before wrapping.
            img_tensor = torch.from_numpy(img.copy()).float() / 255.0
            frames.append(img_tensor)
            frame_idx += 1
    finally:
        # Always release demuxer/decoder resources, even when decoding
        # fails partway through (corrupt file, unsupported codec, ...).
        container.close()
    if not frames:
        raise ValueError(f"No frames could be extracted from {video_path}")
    return torch.stack(frames)  # [T, H, W, C]
class LoadImageDataSetFromFolderNode(io.ComfyNode): class LoadImageDataSetFromFolderNode(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -153,6 +197,166 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
return io.NodeOutput(output_tensor, captions) return io.NodeOutput(output_tensor, captions)
class LoadVideoDataSetFromFolderNode(io.ComfyNode):
    """Load every video in an input subfolder as a list of [T, H, W, C] tensors."""

    @classmethod
    def define_schema(cls):
        """Declare the node: one folder input plus frame-sampling controls."""
        return io.Schema(
            node_id="LoadVideoDataSetFromFolder",
            display_name="Load Video Dataset from Folder",
            category="dataset",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder",
                    options=folder_paths.get_input_subfolders(),
                    tooltip="The folder containing video files.",
                ),
                io.Int.Input(
                    "max_frames",
                    default=0,
                    min=0,
                    max=99999,
                    tooltip="Maximum frames to load per video (0 = all frames).",
                ),
                io.Int.Input(
                    "frame_stride",
                    default=1,
                    min=1,
                    max=1000,
                    tooltip="Sample every Nth frame (1 = every frame).",
                ),
                io.Int.Input(
                    "start_frame",
                    default=0,
                    min=0,
                    max=99999,
                    tooltip="Frame index to start loading from.",
                ),
            ],
            outputs=[
                io.Image.Output(
                    display_name="videos",
                    is_output_list=True,
                    tooltip="List of video tensors, each [T, H, W, C].",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder, max_frames, frame_stride, start_frame):
        """Decode each video file found directly in the selected subfolder."""
        base_dir = os.path.join(folder_paths.get_input_directory(), folder)
        video_files = [
            name
            for name in sorted(os.listdir(base_dir))
            if any(name.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS)
        ]
        if not video_files:
            raise ValueError(f"No video files found in {base_dir}")
        output_videos = []
        for name in video_files:
            frames = load_video_frames(
                os.path.join(base_dir, name), max_frames, frame_stride, start_frame
            )
            output_videos.append(frames)
            logging.info(f"Loaded {name}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}")
        logging.info(f"Loaded {len(output_videos)} videos from {base_dir}")
        return io.NodeOutput(output_videos)
class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
    """Load videos plus same-named .txt captions from an input subfolder.

    Also understands the kohya-ss/sd-scripts layout, where videos live in
    "{repeat}_{description}" subfolders and are duplicated `repeat` times.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node: folder input, sampling controls, video+text outputs."""
        return io.Schema(
            node_id="LoadVideoTextDataSetFromFolder",
            display_name="Load Video and Text Dataset from Folder",
            category="dataset",
            is_experimental=True,
            inputs=[
                io.Combo.Input(
                    "folder",
                    options=folder_paths.get_input_subfolders(),
                    tooltip="The folder containing video files and .txt captions.",
                ),
                io.Int.Input(
                    "max_frames",
                    default=0,
                    min=0,
                    max=99999,
                    tooltip="Maximum frames to load per video (0 = all frames).",
                ),
                io.Int.Input(
                    "frame_stride",
                    default=1,
                    min=1,
                    max=1000,
                    tooltip="Sample every Nth frame (1 = every frame).",
                ),
                io.Int.Input(
                    "start_frame",
                    default=0,
                    min=0,
                    max=99999,
                    tooltip="Frame index to start loading from.",
                ),
            ],
            outputs=[
                io.Image.Output(
                    display_name="videos",
                    is_output_list=True,
                    tooltip="List of video tensors, each [T, H, W, C].",
                ),
                io.String.Output(
                    display_name="texts",
                    is_output_list=True,
                    tooltip="List of text captions.",
                ),
            ],
        )

    @classmethod
    def execute(cls, folder, max_frames, frame_stride, start_frame):
        """Collect video paths (flat and kohya-style), then load captions and frames."""
        sub_dir = os.path.join(folder_paths.get_input_directory(), folder)

        def is_video(filename):
            # Extension check against the module-level whitelist.
            return any(filename.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS)

        video_files = []
        for entry in sorted(os.listdir(sub_dir)):
            entry_path = os.path.join(sub_dir, entry)
            if is_video(entry):
                video_files.append(entry_path)
            elif os.path.isdir(entry_path):
                # Support kohya-ss/sd-scripts folder structure: {repeat}_{desc}/
                prefix = entry.split("_")[0]
                repeat = int(prefix) if prefix.isdigit() else 1
                nested = [
                    os.path.join(entry_path, f)
                    for f in sorted(os.listdir(entry_path))
                    if is_video(f)
                ]
                video_files.extend(nested * repeat)
        if not video_files:
            raise ValueError(f"No video files found in {sub_dir}")

        # Captions share the video's basename with a .txt extension;
        # missing captions become empty strings.
        captions = []
        for path in video_files:
            caption_path = os.path.splitext(path)[0] + ".txt"
            text = ""
            if os.path.exists(caption_path):
                with open(caption_path, "r", encoding="utf-8") as fh:
                    text = fh.read().strip()
            captions.append(text)

        output_videos = [
            load_video_frames(path, max_frames, frame_stride, start_frame)
            for path in video_files
        ]
        logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_dir}")
        return io.NodeOutput(output_videos, captions)
def save_images_to_folder(image_list, output_dir, prefix="image"): def save_images_to_folder(image_list, output_dir, prefix="image"):
"""Utility function to save a list of image tensors to disk. """Utility function to save a list of image tensors to disk.
@ -418,7 +622,15 @@ class ImageProcessingNode(io.ComfyNode):
@classmethod @classmethod
def execute(cls, images, **kwargs): def execute(cls, images, **kwargs):
"""Execute the node. Routes to _process or _group_process based on mode.""" """Execute the node. Routes to _process or _group_process based on mode.
For individual processing (_process), automatically handles multi-frame
inputs (video tensors [T, H, W, C]) by applying _process per-frame and
concatenating the results. This allows all spatial transform nodes to
work with video without modification. Nodes that natively handle batched
tensors (e.g. pure tensor math) can set per_frame_process = False to
skip the per-frame loop.
"""
is_group = cls._detect_processing_mode() is_group = cls._detect_processing_mode()
# Extract scalar values from lists for parameters # Extract scalar values from lists for parameters
@ -434,7 +646,16 @@ class ImageProcessingNode(io.ComfyNode):
result = cls._group_process(images, **params) result = cls._group_process(images, **params)
else: else:
# Individual processing: images is single item, call _process # Individual processing: images is single item, call _process
result = cls._process(images, **params) # Auto-loop over frames for multi-frame inputs (video [T, H, W, C])
# so that PIL-based spatial transforms work per-frame automatically.
if images.shape[0] > 1 and getattr(cls, 'per_frame_process', True):
results = []
for i in range(images.shape[0]):
frame_result = cls._process(images[i:i + 1], **params)
results.append(frame_result)
result = torch.cat(results, dim=0)
else:
result = cls._process(images, **params)
return io.NodeOutput(result) return io.NodeOutput(result)
@ -736,6 +957,7 @@ class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages" node_id = "NormalizeImages"
display_name = "Normalize Images" display_name = "Normalize Images"
description = "Normalize images using mean and standard deviation." description = "Normalize images using mean and standard deviation."
per_frame_process = False # Pure tensor math, handles any batch size
extra_inputs = [ extra_inputs = [
io.Float.Input( io.Float.Input(
"mean", "mean",
@ -764,6 +986,7 @@ class AdjustBrightnessNode(ImageProcessingNode):
node_id = "AdjustBrightness" node_id = "AdjustBrightness"
display_name = "Adjust Brightness" display_name = "Adjust Brightness"
description = "Adjust brightness of all images." description = "Adjust brightness of all images."
per_frame_process = False # Pure tensor math, handles any batch size
extra_inputs = [ extra_inputs = [
io.Float.Input( io.Float.Input(
"factor", "factor",
@ -783,6 +1006,7 @@ class AdjustContrastNode(ImageProcessingNode):
node_id = "AdjustContrast" node_id = "AdjustContrast"
display_name = "Adjust Contrast" display_name = "Adjust Contrast"
description = "Adjust contrast of all images." description = "Adjust contrast of all images."
per_frame_process = False # Pure tensor math, handles any batch size
extra_inputs = [ extra_inputs = [
io.Float.Input( io.Float.Input(
"factor", "factor",
@ -860,6 +1084,243 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
return io.NodeOutput(shuffled_images, shuffled_texts) return io.NodeOutput(shuffled_images, shuffled_texts)
# ========== Video Processing Nodes ==========
class VideoFrameSampleNode(io.ComfyNode):
    """Sample a fixed number of frames from a video using various strategies."""

    @classmethod
    def define_schema(cls):
        """Declare the node: video input, frame count, strategy, and seed."""
        return io.Schema(
            node_id="VideoFrameSample",
            display_name="Video Frame Sample",
            category="dataset/video",
            is_experimental=True,
            inputs=[
                io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
                io.Int.Input(
                    "num_frames",
                    default=16,
                    min=1,
                    max=9999,
                    tooltip="Number of frames to sample.",
                ),
                io.Combo.Input(
                    "strategy",
                    options=["uniform", "head", "tail", "random"],
                    default="uniform",
                    tooltip="uniform: evenly spaced, head: first N, tail: last N, random: random sorted.",
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="Random seed (only used with 'random' strategy).",
                ),
            ],
            outputs=[
                io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."),
            ],
        )

    @classmethod
    def execute(cls, video, num_frames, strategy, seed):
        """Pick frame indices according to `strategy` and index into the video."""
        total = video.shape[0]
        count = min(num_frames, total)  # never request more frames than exist
        if strategy == "uniform":
            if count == 1:
                picks = [total // 2]  # single frame: take the middle one
            else:
                picks = [round(i * (total - 1) / (count - 1)) for i in range(count)]
        elif strategy == "head":
            picks = list(range(count))
        elif strategy == "tail":
            picks = list(range(total - count, total))
        elif strategy == "random":
            # Local RandomState keeps sampling deterministic per seed.
            rng = np.random.RandomState(seed % (2**32 - 1))
            picks = sorted(rng.choice(total, size=count, replace=False).tolist())
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
        return io.NodeOutput(video[picks])
class VideoTemporalCropNode(io.ComfyNode):
    """Crop a continuous range of frames from a video."""

    @classmethod
    def define_schema(cls):
        """Declare the node: video input plus start index and crop length."""
        return io.Schema(
            node_id="VideoTemporalCrop",
            display_name="Video Temporal Crop",
            category="dataset/video",
            is_experimental=True,
            inputs=[
                io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
                io.Int.Input(
                    "start_frame",
                    default=0,
                    min=0,
                    max=99999,
                    tooltip="Starting frame index.",
                ),
                io.Int.Input(
                    "length",
                    default=16,
                    min=1,
                    max=99999,
                    tooltip="Number of frames to keep.",
                ),
            ],
            outputs=[
                io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
            ],
        )

    @classmethod
    def execute(cls, video, start_frame, length):
        """Return video[start:start+length], clamped to the valid frame range."""
        frame_count = video.shape[0]
        # Clamp the start so at least one frame remains reachable.
        begin = min(start_frame, max(frame_count - 1, 0))
        end = min(begin + length, frame_count)
        return io.NodeOutput(video[begin:end])
class VideoRandomTemporalCropNode(io.ComfyNode):
    """Randomly crop a continuous range of frames from a video (for data augmentation)."""

    @classmethod
    def define_schema(cls):
        """Declare the node: video input plus crop length and random seed."""
        return io.Schema(
            node_id="VideoRandomTemporalCrop",
            display_name="Video Random Temporal Crop",
            category="dataset/video",
            is_experimental=True,
            inputs=[
                io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
                io.Int.Input(
                    "length",
                    default=16,
                    min=1,
                    max=99999,
                    tooltip="Number of frames to keep.",
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="Random seed.",
                ),
            ],
            outputs=[
                io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
            ],
        )

    @classmethod
    def execute(cls, video, length, seed):
        """Pick a random window start (seeded) and slice out `length` frames."""
        total = video.shape[0]
        keep = min(length, total)  # never crop longer than the video
        slack = total - keep
        rng = np.random.RandomState(seed % (2**32 - 1))
        start = rng.randint(0, slack + 1) if slack > 0 else 0
        return io.NodeOutput(video[start:start + keep])
class ShuffleVideoDatasetNode(ImageProcessingNode):
    """Randomly shuffle the order of videos in the dataset."""

    node_id = "ShuffleVideoDataset"
    display_name = "Shuffle Video Dataset"
    description = "Randomly shuffle the order of videos in the dataset."
    is_group_process = True
    extra_inputs = [
        io.Int.Input(
            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
        ),
    ]

    @classmethod
    def define_schema(cls):
        """Declare the node: a list of videos plus a shuffle seed."""
        return io.Schema(
            node_id=cls.node_id,
            display_name=cls.display_name,
            category="dataset/video",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Image.Input("images", tooltip="List of videos to shuffle."),
                io.Int.Input(
                    "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
                ),
            ],
            outputs=[
                io.Image.Output(
                    display_name="videos",
                    is_output_list=True,
                    tooltip="Shuffled videos",
                ),
            ],
        )

    @classmethod
    def execute(cls, images, seed):
        """Return `images` reordered by a seeded random permutation.

        Uses a local RandomState (like VideoRandomTemporalCropNode) instead of
        np.random.seed, which would clobber NumPy's *global* RNG state and
        silently affect every other node sharing the process. The resulting
        permutation is identical for a given seed.
        """
        # Inputs arrive wrapped in lists when is_input_list is set.
        seed = seed[0] if isinstance(seed, list) else seed
        rng = np.random.RandomState(seed % (2**32 - 1))
        indices = rng.permutation(len(images))
        return io.NodeOutput([images[i] for i in indices])
class ShuffleVideoTextDatasetNode(io.ComfyNode):
    """Shuffle videos and their captions together, preserving pairs."""

    @classmethod
    def define_schema(cls):
        """Declare the node: paired video/text lists plus a shuffle seed."""
        return io.Schema(
            node_id="ShuffleVideoTextDataset",
            display_name="Shuffle Video-Text Dataset",
            category="dataset/video",
            is_experimental=True,
            is_input_list=True,
            inputs=[
                io.Image.Input("videos", tooltip="List of videos to shuffle."),
                io.String.Input("texts", tooltip="List of texts to shuffle."),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    tooltip="Random seed.",
                ),
            ],
            outputs=[
                io.Image.Output(
                    display_name="videos",
                    is_output_list=True,
                    tooltip="Shuffled videos",
                ),
                io.String.Output(
                    display_name="texts",
                    is_output_list=True,
                    tooltip="Shuffled texts",
                ),
            ],
        )

    @classmethod
    def execute(cls, videos, texts, seed):
        """Apply one seeded permutation to both lists so pairs stay aligned.

        Uses a local RandomState instead of np.random.seed, which would
        clobber NumPy's *global* RNG state and silently affect every other
        node in the process. The permutation is identical for a given seed.
        """
        # Inputs arrive wrapped in lists when is_input_list is set.
        seed = seed[0] if isinstance(seed, list) else seed
        rng = np.random.RandomState(seed % (2**32 - 1))
        indices = rng.permutation(len(videos))
        return io.NodeOutput(
            [videos[i] for i in indices],
            [texts[i] for i in indices],
        )
# ========== Text Transform Nodes ========== # ========== Text Transform Nodes ==========
@ -1502,7 +1963,10 @@ class DatasetExtension(ComfyExtension):
LoadImageTextDataSetFromFolderNode, LoadImageTextDataSetFromFolderNode,
SaveImageDataSetToFolderNode, SaveImageDataSetToFolderNode,
SaveImageTextDataSetToFolderNode, SaveImageTextDataSetToFolderNode,
# Image transform nodes # Video data loading nodes
LoadVideoDataSetFromFolderNode,
LoadVideoTextDataSetFromFolderNode,
# Image transform nodes (auto-handle video via per-frame processing)
ResizeImagesByShorterEdgeNode, ResizeImagesByShorterEdgeNode,
ResizeImagesByLongerEdgeNode, ResizeImagesByLongerEdgeNode,
CenterCropImagesNode, CenterCropImagesNode,
@ -1512,6 +1976,12 @@ class DatasetExtension(ComfyExtension):
AdjustContrastNode, AdjustContrastNode,
ShuffleDatasetNode, ShuffleDatasetNode,
ShuffleImageTextDatasetNode, ShuffleImageTextDatasetNode,
# Video processing nodes
VideoFrameSampleNode,
VideoTemporalCropNode,
VideoRandomTemporalCropNode,
ShuffleVideoDatasetNode,
ShuffleVideoTextDatasetNode,
# Text transform nodes # Text transform nodes
TextToLowercaseNode, TextToLowercaseNode,
TextToUppercaseNode, TextToUppercaseNode,