Init implementation of video dataset nodes

parent b615af1c65
commit d364d3f8b5
@@ -2,6 +2,7 @@ import logging
 import os
 import json
 
+import av
 import numpy as np
 import torch
 from PIL import Image
@@ -42,6 +43,49 @@ def load_and_process_images(image_files, input_dir):
     return output_images
 
 
+VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]
+
+
+def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0):
+    """Load video file and return frames as a tensor.
+
+    Args:
+        video_path: Path to the video file
+        max_frames: Maximum number of frames to load (0 = all)
+        frame_stride: Sample every Nth frame
+        start_frame: Frame index to start from
+
+    Returns:
+        torch.Tensor: Video frames as [T, H, W, C] float32 tensor in [0, 1]
+    """
+    container = av.open(video_path)
+    stream = container.streams.video[0]
+
+    frames = []
+    frame_idx = 0
+    for frame in container.decode(stream):
+        if frame_idx < start_frame:
+            frame_idx += 1
+            continue
+        if (frame_idx - start_frame) % frame_stride != 0:
+            frame_idx += 1
+            continue
+        if max_frames > 0 and len(frames) >= max_frames:
+            break
+
+        img = frame.to_ndarray(format='rgb24')
+        img_tensor = torch.from_numpy(img.copy()).float() / 255.0
+        frames.append(img_tensor)
+        frame_idx += 1
+
+    container.close()
+
+    if not frames:
+        raise ValueError(f"No frames could be extracted from {video_path}")
+
+    return torch.stack(frames)  # [T, H, W, C]
+
+
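Note: av here is PyAV, the FFmpeg binding this commit starts importing. For a given (start_frame, frame_stride, max_frames), the decode loop above keeps source frames start_frame, start_frame + frame_stride, start_frame + 2*frame_stride, and so on, stopping once max_frames frames are collected. A minimal sketch (hypothetical helper, not part of the commit) that mirrors the index arithmetic:

def kept_frame_indices(total_frames, max_frames=0, frame_stride=1, start_frame=0):
    # Same filtering order as load_video_frames: skip pre-start frames,
    # skip off-stride frames, stop once max_frames are collected.
    kept = []
    for frame_idx in range(total_frames):
        if frame_idx < start_frame:
            continue
        if (frame_idx - start_frame) % frame_stride != 0:
            continue
        if max_frames > 0 and len(kept) >= max_frames:
            break
        kept.append(frame_idx)
    return kept

print(kept_frame_indices(30, max_frames=4, frame_stride=3, start_frame=2))  # [2, 5, 8, 11]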
 class LoadImageDataSetFromFolderNode(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -153,6 +197,166 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
         return io.NodeOutput(output_tensor, captions)
 
 
+class LoadVideoDataSetFromFolderNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoadVideoDataSetFromFolder",
+            display_name="Load Video Dataset from Folder",
+            category="dataset",
+            is_experimental=True,
+            inputs=[
+                io.Combo.Input(
+                    "folder",
+                    options=folder_paths.get_input_subfolders(),
+                    tooltip="The folder containing video files.",
+                ),
+                io.Int.Input(
+                    "max_frames",
+                    default=0,
+                    min=0,
+                    max=99999,
+                    tooltip="Maximum frames to load per video (0 = all frames).",
+                ),
+                io.Int.Input(
+                    "frame_stride",
+                    default=1,
+                    min=1,
+                    max=1000,
+                    tooltip="Sample every Nth frame (1 = every frame).",
+                ),
+                io.Int.Input(
+                    "start_frame",
+                    default=0,
+                    min=0,
+                    max=99999,
+                    tooltip="Frame index to start loading from.",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(
+                    display_name="videos",
+                    is_output_list=True,
+                    tooltip="List of video tensors, each [T, H, W, C].",
+                ),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, folder, max_frames, frame_stride, start_frame):
+        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+        video_files = sorted([
+            f for f in os.listdir(sub_input_dir)
+            if any(f.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS)
+        ])
+
+        if not video_files:
+            raise ValueError(f"No video files found in {sub_input_dir}")
+
+        output_videos = []
+        for file in video_files:
+            video_path = os.path.join(sub_input_dir, file)
+            frames = load_video_frames(video_path, max_frames, frame_stride, start_frame)
+            output_videos.append(frames)
+            logging.info(f"Loaded {file}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}")
+
+        logging.info(f"Loaded {len(output_videos)} videos from {sub_input_dir}")
+        return io.NodeOutput(output_videos)
+
+
+class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LoadVideoTextDataSetFromFolder",
+            display_name="Load Video and Text Dataset from Folder",
+            category="dataset",
+            is_experimental=True,
+            inputs=[
+                io.Combo.Input(
+                    "folder",
+                    options=folder_paths.get_input_subfolders(),
+                    tooltip="The folder containing video files and .txt captions.",
+                ),
+                io.Int.Input(
+                    "max_frames",
+                    default=0,
+                    min=0,
+                    max=99999,
+                    tooltip="Maximum frames to load per video (0 = all frames).",
+                ),
+                io.Int.Input(
+                    "frame_stride",
+                    default=1,
+                    min=1,
+                    max=1000,
+                    tooltip="Sample every Nth frame (1 = every frame).",
+                ),
+                io.Int.Input(
+                    "start_frame",
+                    default=0,
+                    min=0,
+                    max=99999,
+                    tooltip="Frame index to start loading from.",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(
+                    display_name="videos",
+                    is_output_list=True,
+                    tooltip="List of video tensors, each [T, H, W, C].",
+                ),
+                io.String.Output(
+                    display_name="texts",
+                    is_output_list=True,
+                    tooltip="List of text captions.",
+                ),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, folder, max_frames, frame_stride, start_frame):
+        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+
+        video_files = []
+        for item in sorted(os.listdir(sub_input_dir)):
+            path = os.path.join(sub_input_dir, item)
+            if any(item.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS):
+                video_files.append(path)
+            elif os.path.isdir(path):
+                # Support kohya-ss/sd-scripts folder structure: {repeat}_{desc}/
+                repeat = 1
+                if item.split("_")[0].isdigit():
+                    repeat = int(item.split("_")[0])
+                video_files.extend([
+                    os.path.join(path, f)
+                    for f in sorted(os.listdir(path))
+                    if any(f.lower().endswith(ext) for ext in VALID_VIDEO_EXTENSIONS)
+                ] * repeat)
+
+        if not video_files:
+            raise ValueError(f"No video files found in {sub_input_dir}")
+
+        # Load captions (same name as video but .txt)
+        captions = []
+        for vf in video_files:
+            caption_path = os.path.splitext(vf)[0] + ".txt"
+            if os.path.exists(caption_path):
+                with open(caption_path, "r", encoding="utf-8") as f:
+                    captions.append(f.read().strip())
+            else:
+                captions.append("")
+
+        # Load videos
+        output_videos = []
+        for vf in video_files:
+            frames = load_video_frames(vf, max_frames, frame_stride, start_frame)
+            output_videos.append(frames)
+
+        logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_input_dir}")
+        return io.NodeOutput(output_videos, captions)
+
+
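Captions are sidecar files: each video pairs with a .txt sharing its basename, and a missing caption falls back to the empty string. The subdirectory branch follows the kohya-ss/sd-scripts convention the comment names, where a folder called {repeat}_{desc} repeats its videos {repeat} times in the dataset. An illustrative sketch (not part of the commit) of that prefix parsing:

def folder_repeat(name):
    # "3_dog" -> 3 (videos inside are listed 3 times); "dog_clips" -> 1
    prefix = name.split("_")[0]
    return int(prefix) if prefix.isdigit() else 1

assert folder_repeat("3_dog") == 3
assert folder_repeat("dog_clips") == 1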
 def save_images_to_folder(image_list, output_dir, prefix="image"):
     """Utility function to save a list of image tensors to disk.
 
@@ -418,7 +622,15 @@ class ImageProcessingNode(io.ComfyNode):
 
     @classmethod
     def execute(cls, images, **kwargs):
-        """Execute the node. Routes to _process or _group_process based on mode."""
+        """Execute the node. Routes to _process or _group_process based on mode.
+
+        For individual processing (_process), automatically handles multi-frame
+        inputs (video tensors [T, H, W, C]) by applying _process per-frame and
+        concatenating the results. This allows all spatial transform nodes to
+        work with video without modification. Nodes that natively handle batched
+        tensors (e.g. pure tensor math) can set per_frame_process = False to
+        skip the per-frame loop.
+        """
         is_group = cls._detect_processing_mode()
 
         # Extract scalar values from lists for parameters
@@ -434,7 +646,16 @@ class ImageProcessingNode(io.ComfyNode):
             result = cls._group_process(images, **params)
         else:
             # Individual processing: images is single item, call _process
-            result = cls._process(images, **params)
+            # Auto-loop over frames for multi-frame inputs (video [T, H, W, C])
+            # so that PIL-based spatial transforms work per-frame automatically.
+            if images.shape[0] > 1 and getattr(cls, 'per_frame_process', True):
+                results = []
+                for i in range(images.shape[0]):
+                    frame_result = cls._process(images[i:i + 1], **params)
+                    results.append(frame_result)
+                result = torch.cat(results, dim=0)
+            else:
+                result = cls._process(images, **params)
 
         return io.NodeOutput(result)
 
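A standalone sketch of the per-frame fallback these two hunks add, outside the node framework and with assumed shapes: a transform written for a single [1, H, W, C] image is applied to each frame of a [T, H, W, C] video, and the results are concatenated back along the time axis. flip_h is a hypothetical stand-in for a node's _process:

import torch

def flip_h(img):
    # Stand-in for a PIL-style _process that expects one [1, H, W, C] image.
    return torch.flip(img, dims=[2])

video = torch.rand(8, 64, 64, 3)  # [T, H, W, C]
out = torch.cat([flip_h(video[i:i + 1]) for i in range(video.shape[0])], dim=0)
assert out.shape == video.shape  # frames processed one by one, then re-stacked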
@@ -736,6 +957,7 @@ class NormalizeImagesNode(ImageProcessingNode):
     node_id = "NormalizeImages"
     display_name = "Normalize Images"
     description = "Normalize images using mean and standard deviation."
+    per_frame_process = False  # Pure tensor math, handles any batch size
     extra_inputs = [
         io.Float.Input(
             "mean",
@@ -764,6 +986,7 @@ class AdjustBrightnessNode(ImageProcessingNode):
     node_id = "AdjustBrightness"
     display_name = "Adjust Brightness"
     description = "Adjust brightness of all images."
+    per_frame_process = False  # Pure tensor math, handles any batch size
     extra_inputs = [
         io.Float.Input(
             "factor",
@@ -783,6 +1006,7 @@ class AdjustContrastNode(ImageProcessingNode):
     node_id = "AdjustContrast"
     display_name = "Adjust Contrast"
     description = "Adjust contrast of all images."
+    per_frame_process = False  # Pure tensor math, handles any batch size
     extra_inputs = [
         io.Float.Input(
             "factor",
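These three nodes can opt out of the per-frame loop because elementwise tensor math already broadcasts over the leading time/batch dimension. An illustrative sketch, assuming brightness adjustment is a plain scale (the hunk does not show the actual operation):

import torch

video = torch.rand(16, 32, 32, 3)           # [T, H, W, C]
brightened = (video * 1.2).clamp(0.0, 1.0)  # one elementwise op covers every frame
assert brightened.shape == video.shape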
@@ -860,6 +1084,243 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
         return io.NodeOutput(shuffled_images, shuffled_texts)
 
 
+# ========== Video Processing Nodes ==========
+
+
+class VideoFrameSampleNode(io.ComfyNode):
+    """Sample a fixed number of frames from a video using various strategies."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="VideoFrameSample",
+            display_name="Video Frame Sample",
+            category="dataset/video",
+            is_experimental=True,
+            inputs=[
+                io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
+                io.Int.Input(
+                    "num_frames",
+                    default=16,
+                    min=1,
+                    max=9999,
+                    tooltip="Number of frames to sample.",
+                ),
+                io.Combo.Input(
+                    "strategy",
+                    options=["uniform", "head", "tail", "random"],
+                    default="uniform",
+                    tooltip="uniform: evenly spaced, head: first N, tail: last N, random: random sorted.",
+                ),
+                io.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    tooltip="Random seed (only used with 'random' strategy).",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, video, num_frames, strategy, seed):
+        total_frames = video.shape[0]
+        num_frames = min(num_frames, total_frames)
+
+        if strategy == "head":
+            indices = list(range(num_frames))
+        elif strategy == "tail":
+            indices = list(range(total_frames - num_frames, total_frames))
+        elif strategy == "uniform":
+            if num_frames == 1:
+                indices = [total_frames // 2]
+            else:
+                indices = [round(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]
+        elif strategy == "random":
+            rng = np.random.RandomState(seed % (2**32 - 1))
+            indices = sorted(rng.choice(total_frames, size=num_frames, replace=False).tolist())
+        else:
+            raise ValueError(f"Unknown strategy: {strategy}")
+
+        return io.NodeOutput(video[indices])
+
+
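A worked example of the "uniform" strategy: round(i * (T - 1) / (N - 1)) spaces N indices evenly across T frames and always includes the first and last frame. Note that Python's round() uses banker's rounding on exact halves:

total_frames, num_frames = 30, 5
indices = [round(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]
print(indices)  # [0, 7, 14, 22, 29]; 14.5 rounds down to the even 14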
+class VideoTemporalCropNode(io.ComfyNode):
+    """Crop a continuous range of frames from a video."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="VideoTemporalCrop",
+            display_name="Video Temporal Crop",
+            category="dataset/video",
+            is_experimental=True,
+            inputs=[
+                io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
+                io.Int.Input(
+                    "start_frame",
+                    default=0,
+                    min=0,
+                    max=99999,
+                    tooltip="Starting frame index.",
+                ),
+                io.Int.Input(
+                    "length",
+                    default=16,
+                    min=1,
+                    max=99999,
+                    tooltip="Number of frames to keep.",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, video, start_frame, length):
+        total_frames = video.shape[0]
+        start_frame = min(start_frame, max(total_frames - 1, 0))
+        end_frame = min(start_frame + length, total_frames)
+        return io.NodeOutput(video[start_frame:end_frame])
+
+
+class VideoRandomTemporalCropNode(io.ComfyNode):
+    """Randomly crop a continuous range of frames from a video (for data augmentation)."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="VideoRandomTemporalCrop",
+            display_name="Video Random Temporal Crop",
+            category="dataset/video",
+            is_experimental=True,
+            inputs=[
+                io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
+                io.Int.Input(
+                    "length",
+                    default=16,
+                    min=1,
+                    max=99999,
+                    tooltip="Number of frames to keep.",
+                ),
+                io.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    tooltip="Random seed.",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, video, length, seed):
+        total_frames = video.shape[0]
+        length = min(length, total_frames)
+        max_start = total_frames - length
+        rng = np.random.RandomState(seed % (2**32 - 1))
+        start = rng.randint(0, max_start + 1) if max_start > 0 else 0
+        return io.NodeOutput(video[start:start + length])
+
+
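For the random crop, the last valid start index for a length-L crop of a T-frame video is T - L, which is why the exclusive upper bound passed to rng.randint is max_start + 1. A quick sketch of the bound arithmetic with illustrative values:

import numpy as np

total_frames, length, seed = 30, 16, 42
max_start = total_frames - length  # latest start that still yields a full clip
rng = np.random.RandomState(seed % (2**32 - 1))
start = rng.randint(0, max_start + 1) if max_start > 0 else 0
assert 0 <= start <= max_start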
+class ShuffleVideoDatasetNode(ImageProcessingNode):
+    """Randomly shuffle the order of videos in the dataset."""
+
+    node_id = "ShuffleVideoDataset"
+    display_name = "Shuffle Video Dataset"
+    description = "Randomly shuffle the order of videos in the dataset."
+    is_group_process = True
+    extra_inputs = [
+        io.Int.Input(
+            "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
+        ),
+    ]
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id=cls.node_id,
+            display_name=cls.display_name,
+            category="dataset/video",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=[
+                io.Image.Input("images", tooltip="List of videos to shuffle."),
+                io.Int.Input(
+                    "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
+                ),
+            ],
+            outputs=[
+                io.Image.Output(
+                    display_name="videos",
+                    is_output_list=True,
+                    tooltip="Shuffled videos",
+                ),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, images, seed):
+        seed = seed[0] if isinstance(seed, list) else seed
+        np.random.seed(seed % (2**32 - 1))
+        indices = np.random.permutation(len(images))
+        return io.NodeOutput([images[i] for i in indices])
+
+
+class ShuffleVideoTextDatasetNode(io.ComfyNode):
+    """Shuffle videos and their captions together, preserving pairs."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ShuffleVideoTextDataset",
+            display_name="Shuffle Video-Text Dataset",
+            category="dataset/video",
+            is_experimental=True,
+            is_input_list=True,
+            inputs=[
+                io.Image.Input("videos", tooltip="List of videos to shuffle."),
+                io.String.Input("texts", tooltip="List of texts to shuffle."),
+                io.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    tooltip="Random seed.",
+                ),
+            ],
+            outputs=[
+                io.Image.Output(
+                    display_name="videos",
+                    is_output_list=True,
+                    tooltip="Shuffled videos",
+                ),
+                io.String.Output(
+                    display_name="texts",
+                    is_output_list=True,
+                    tooltip="Shuffled texts",
+                ),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, videos, texts, seed):
+        seed = seed[0] if isinstance(seed, list) else seed
+        np.random.seed(seed % (2**32 - 1))
+        indices = np.random.permutation(len(videos))
+        return io.NodeOutput(
+            [videos[i] for i in indices],
+            [texts[i] for i in indices],
+        )
+
+
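Because both shuffle nodes declare is_input_list=True, every input arrives as a list, including the scalar seed, hence the seed[0] unwrapping in execute. The paired shuffle draws one permutation and applies it to both lists so video/caption alignment is preserved. A minimal sketch with hypothetical data:

import numpy as np

videos = ["clip_a", "clip_b", "clip_c"]  # stand-ins for [T, H, W, C] tensors
texts = ["a dog", "a cat", "a bird"]
np.random.seed(0)
indices = np.random.permutation(len(videos))
shuffled = [(videos[i], texts[i]) for i in indices]  # pairs stay aligned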
 # ========== Text Transform Nodes ==========
 
 
@@ -1502,7 +1963,10 @@ class DatasetExtension(ComfyExtension):
             LoadImageTextDataSetFromFolderNode,
             SaveImageDataSetToFolderNode,
             SaveImageTextDataSetToFolderNode,
-            # Image transform nodes
+            # Video data loading nodes
+            LoadVideoDataSetFromFolderNode,
+            LoadVideoTextDataSetFromFolderNode,
+            # Image transform nodes (auto-handle video via per-frame processing)
             ResizeImagesByShorterEdgeNode,
             ResizeImagesByLongerEdgeNode,
             CenterCropImagesNode,
@@ -1512,6 +1976,12 @@ class DatasetExtension(ComfyExtension):
             AdjustContrastNode,
             ShuffleDatasetNode,
             ShuffleImageTextDatasetNode,
+            # Video processing nodes
+            VideoFrameSampleNode,
+            VideoTemporalCropNode,
+            VideoRandomTemporalCropNode,
+            ShuffleVideoDatasetNode,
+            ShuffleVideoTextDatasetNode,
             # Text transform nodes
             TextToLowercaseNode,
             TextToUppercaseNode,