Init implementation of video dataset nodes

This commit is contained in:
Kohaku-Blueleaf 2026-04-14 14:39:56 +08:00
parent b615af1c65
commit d364d3f8b5

View File

@ -2,6 +2,7 @@ import logging
import os
import json
import av
import numpy as np
import torch
from PIL import Image
@ -42,6 +43,49 @@ def load_and_process_images(image_files, input_dir):
return output_images
VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]
def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0):
"""Load video file and return frames as a tensor.
Args:
video_path: Path to the video file
max_frames: Maximum number of frames to load (0 = all)
frame_stride: Sample every Nth frame
start_frame: Frame index to start from
Returns:
torch.Tensor: Video frames as [T, H, W, C] float32 tensor in [0, 1]
"""
container = av.open(video_path)
stream = container.streams.video[0]
frames = []
frame_idx = 0
for frame in container.decode(stream):
if frame_idx < start_frame:
frame_idx += 1
continue
if (frame_idx - start_frame) % frame_stride != 0:
frame_idx += 1
continue
if max_frames > 0 and len(frames) >= max_frames:
break
img = frame.to_ndarray(format='rgb24')
img_tensor = torch.from_numpy(img.copy()).float() / 255.0
frames.append(img_tensor)
frame_idx += 1
container.close()
if not frames:
raise ValueError(f"No frames could be extracted from {video_path}")
return torch.stack(frames) # [T, H, W, C]
class LoadImageDataSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -153,6 +197,166 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
return io.NodeOutput(output_tensor, captions)
class LoadVideoDataSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadVideoDataSetFromFolder",
display_name="Load Video Dataset from Folder",
category="dataset",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder",
options=folder_paths.get_input_subfolders(),
tooltip="The folder containing video files.",
),
io.Int.Input(
"max_frames",
default=0,
min=0,
max=99999,
tooltip="Maximum frames to load per video (0 = all frames).",
),
io.Int.Input(
"frame_stride",
default=1,
min=1,
max=1000,
tooltip="Sample every Nth frame (1 = every frame).",
),
io.Int.Input(
"start_frame",
default=0,
min=0,
max=99999,
tooltip="Frame index to start loading from.",
),
],
outputs=[
io.Image.Output(
display_name="videos",
is_output_list=True,
tooltip="List of video tensors, each [T, H, W, C].",
),
],
)
@classmethod
def execute(cls, folder, max_frames, frame_stride, start_frame):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
video_files = sorted([
f for f in os.listdir(sub_input_dir)
if any(f.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS)
])
if not video_files:
raise ValueError(f"No video files found in {sub_input_dir}")
output_videos = []
for file in video_files:
video_path = os.path.join(sub_input_dir, file)
frames = load_video_frames(video_path, max_frames, frame_stride, start_frame)
output_videos.append(frames)
logging.info(f"Loaded {file}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}")
logging.info(f"Loaded {len(output_videos)} videos from {sub_input_dir}")
return io.NodeOutput(output_videos)
class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadVideoTextDataSetFromFolder",
display_name="Load Video and Text Dataset from Folder",
category="dataset",
is_experimental=True,
inputs=[
io.Combo.Input(
"folder",
options=folder_paths.get_input_subfolders(),
tooltip="The folder containing video files and .txt captions.",
),
io.Int.Input(
"max_frames",
default=0,
min=0,
max=99999,
tooltip="Maximum frames to load per video (0 = all frames).",
),
io.Int.Input(
"frame_stride",
default=1,
min=1,
max=1000,
tooltip="Sample every Nth frame (1 = every frame).",
),
io.Int.Input(
"start_frame",
default=0,
min=0,
max=99999,
tooltip="Frame index to start loading from.",
),
],
outputs=[
io.Image.Output(
display_name="videos",
is_output_list=True,
tooltip="List of video tensors, each [T, H, W, C].",
),
io.String.Output(
display_name="texts",
is_output_list=True,
tooltip="List of text captions.",
),
],
)
@classmethod
def execute(cls, folder, max_frames, frame_stride, start_frame):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
video_files = []
for item in sorted(os.listdir(sub_input_dir)):
path = os.path.join(sub_input_dir, item)
if any(item.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS):
video_files.append(path)
elif os.path.isdir(path):
# Support kohya-ss/sd-scripts folder structure: {repeat}_{desc}/
repeat = 1
if item.split("_")[0].isdigit():
repeat = int(item.split("_")[0])
video_files.extend([
os.path.join(path, f)
for f in sorted(os.listdir(path))
if any(f.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS)
] * repeat)
if not video_files:
raise ValueError(f"No video files found in {sub_input_dir}")
# Load captions (same name as video but .txt)
captions = []
for vf in video_files:
caption_path = os.path.splitext(vf)[0] + ".txt"
if os.path.exists(caption_path):
with open(caption_path, "r", encoding="utf-8") as f:
captions.append(f.read().strip())
else:
captions.append("")
# Load videos
output_videos = []
for vf in video_files:
frames = load_video_frames(vf, max_frames, frame_stride, start_frame)
output_videos.append(frames)
logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_input_dir}")
return io.NodeOutput(output_videos, captions)
def save_images_to_folder(image_list, output_dir, prefix="image"):
"""Utility function to save a list of image tensors to disk.
@ -418,7 +622,15 @@ class ImageProcessingNode(io.ComfyNode):
@classmethod
def execute(cls, images, **kwargs):
"""Execute the node. Routes to _process or _group_process based on mode."""
"""Execute the node. Routes to _process or _group_process based on mode.
For individual processing (_process), automatically handles multi-frame
inputs (video tensors [T, H, W, C]) by applying _process per-frame and
concatenating the results. This allows all spatial transform nodes to
work with video without modification. Nodes that natively handle batched
tensors (e.g. pure tensor math) can set per_frame_process = False to
skip the per-frame loop.
"""
is_group = cls._detect_processing_mode()
# Extract scalar values from lists for parameters
@ -434,7 +646,16 @@ class ImageProcessingNode(io.ComfyNode):
result = cls._group_process(images, **params)
else:
# Individual processing: images is single item, call _process
result = cls._process(images, **params)
# Auto-loop over frames for multi-frame inputs (video [T, H, W, C])
# so that PIL-based spatial transforms work per-frame automatically.
if images.shape[0] > 1 and getattr(cls, 'per_frame_process', True):
results = []
for i in range(images.shape[0]):
frame_result = cls._process(images[i:i + 1], **params)
results.append(frame_result)
result = torch.cat(results, dim=0)
else:
result = cls._process(images, **params)
return io.NodeOutput(result)
@ -736,6 +957,7 @@ class NormalizeImagesNode(ImageProcessingNode):
node_id = "NormalizeImages"
display_name = "Normalize Images"
description = "Normalize images using mean and standard deviation."
per_frame_process = False # Pure tensor math, handles any batch size
extra_inputs = [
io.Float.Input(
"mean",
@ -764,6 +986,7 @@ class AdjustBrightnessNode(ImageProcessingNode):
node_id = "AdjustBrightness"
display_name = "Adjust Brightness"
description = "Adjust brightness of all images."
per_frame_process = False # Pure tensor math, handles any batch size
extra_inputs = [
io.Float.Input(
"factor",
@ -783,6 +1006,7 @@ class AdjustContrastNode(ImageProcessingNode):
node_id = "AdjustContrast"
display_name = "Adjust Contrast"
description = "Adjust contrast of all images."
per_frame_process = False # Pure tensor math, handles any batch size
extra_inputs = [
io.Float.Input(
"factor",
@ -860,6 +1084,243 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
return io.NodeOutput(shuffled_images, shuffled_texts)
# ========== Video Processing Nodes ==========
class VideoFrameSampleNode(io.ComfyNode):
"""Sample a fixed number of frames from a video using various strategies."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoFrameSample",
display_name="Video Frame Sample",
category="dataset/video",
is_experimental=True,
inputs=[
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
io.Int.Input(
"num_frames",
default=16,
min=1,
max=9999,
tooltip="Number of frames to sample.",
),
io.Combo.Input(
"strategy",
options=["uniform", "head", "tail", "random"],
default="uniform",
tooltip="uniform: evenly spaced, head: first N, tail: last N, random: random sorted.",
),
io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
tooltip="Random seed (only used with 'random' strategy).",
),
],
outputs=[
io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."),
],
)
@classmethod
def execute(cls, video, num_frames, strategy, seed):
total_frames = video.shape[0]
num_frames = min(num_frames, total_frames)
if strategy == "head":
indices = list(range(num_frames))
elif strategy == "tail":
indices = list(range(total_frames - num_frames, total_frames))
elif strategy == "uniform":
if num_frames == 1:
indices = [total_frames // 2]
else:
indices = [round(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]
elif strategy == "random":
rng = np.random.RandomState(seed % (2**32 - 1))
indices = sorted(rng.choice(total_frames, size=num_frames, replace=False).tolist())
else:
raise ValueError(f"Unknown strategy: {strategy}")
return io.NodeOutput(video[indices])
class VideoTemporalCropNode(io.ComfyNode):
"""Crop a continuous range of frames from a video."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoTemporalCrop",
display_name="Video Temporal Crop",
category="dataset/video",
is_experimental=True,
inputs=[
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
io.Int.Input(
"start_frame",
default=0,
min=0,
max=99999,
tooltip="Starting frame index.",
),
io.Int.Input(
"length",
default=16,
min=1,
max=99999,
tooltip="Number of frames to keep.",
),
],
outputs=[
io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
],
)
@classmethod
def execute(cls, video, start_frame, length):
total_frames = video.shape[0]
start_frame = min(start_frame, max(total_frames - 1, 0))
end_frame = min(start_frame + length, total_frames)
return io.NodeOutput(video[start_frame:end_frame])
class VideoRandomTemporalCropNode(io.ComfyNode):
"""Randomly crop a continuous range of frames from a video (for data augmentation)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoRandomTemporalCrop",
display_name="Video Random Temporal Crop",
category="dataset/video",
is_experimental=True,
inputs=[
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
io.Int.Input(
"length",
default=16,
min=1,
max=99999,
tooltip="Number of frames to keep.",
),
io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
tooltip="Random seed.",
),
],
outputs=[
io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
],
)
@classmethod
def execute(cls, video, length, seed):
total_frames = video.shape[0]
length = min(length, total_frames)
max_start = total_frames - length
rng = np.random.RandomState(seed % (2**32 - 1))
start = rng.randint(0, max_start + 1) if max_start > 0 else 0
return io.NodeOutput(video[start:start + length])
class ShuffleVideoDatasetNode(ImageProcessingNode):
"""Randomly shuffle the order of videos in the dataset."""
node_id = "ShuffleVideoDataset"
display_name = "Shuffle Video Dataset"
description = "Randomly shuffle the order of videos in the dataset."
is_group_process = True
extra_inputs = [
io.Int.Input(
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
),
]
@classmethod
def define_schema(cls):
return io.Schema(
node_id=cls.node_id,
display_name=cls.display_name,
category="dataset/video",
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("images", tooltip="List of videos to shuffle."),
io.Int.Input(
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
),
],
outputs=[
io.Image.Output(
display_name="videos",
is_output_list=True,
tooltip="Shuffled videos",
),
],
)
@classmethod
def execute(cls, images, seed):
seed = seed[0] if isinstance(seed, list) else seed
np.random.seed(seed % (2**32 - 1))
indices = np.random.permutation(len(images))
return io.NodeOutput([images[i] for i in indices])
class ShuffleVideoTextDatasetNode(io.ComfyNode):
"""Shuffle videos and their captions together, preserving pairs."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ShuffleVideoTextDataset",
display_name="Shuffle Video-Text Dataset",
category="dataset/video",
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("videos", tooltip="List of videos to shuffle."),
io.String.Input("texts", tooltip="List of texts to shuffle."),
io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
tooltip="Random seed.",
),
],
outputs=[
io.Image.Output(
display_name="videos",
is_output_list=True,
tooltip="Shuffled videos",
),
io.String.Output(
display_name="texts",
is_output_list=True,
tooltip="Shuffled texts",
),
],
)
@classmethod
def execute(cls, videos, texts, seed):
seed = seed[0] if isinstance(seed, list) else seed
np.random.seed(seed % (2**32 - 1))
indices = np.random.permutation(len(videos))
return io.NodeOutput(
[videos[i] for i in indices],
[texts[i] for i in indices],
)
# ========== Text Transform Nodes ==========
@ -1502,7 +1963,10 @@ class DatasetExtension(ComfyExtension):
LoadImageTextDataSetFromFolderNode,
SaveImageDataSetToFolderNode,
SaveImageTextDataSetToFolderNode,
# Image transform nodes
# Video data loading nodes
LoadVideoDataSetFromFolderNode,
LoadVideoTextDataSetFromFolderNode,
# Image transform nodes (auto-handle video via per-frame processing)
ResizeImagesByShorterEdgeNode,
ResizeImagesByLongerEdgeNode,
CenterCropImagesNode,
@ -1512,6 +1976,12 @@ class DatasetExtension(ComfyExtension):
AdjustContrastNode,
ShuffleDatasetNode,
ShuffleImageTextDatasetNode,
# Video processing nodes
VideoFrameSampleNode,
VideoTemporalCropNode,
VideoRandomTemporalCropNode,
ShuffleVideoDatasetNode,
ShuffleVideoTextDatasetNode,
# Text transform nodes
TextToLowercaseNode,
TextToUppercaseNode,