refactor(video dataset): lazy video loading and frame-selective decode

- Replace eager load_video_frames() with _decode_selected_frames() that
  opens the container with `with av.open(...)` (no resource leak) and
  decodes only the requested frame indices.
- Video loader nodes now emit lazy VideoFromFile references; sampling and
  temporal-crop nodes operate lazily / decode only selected frames.
This commit is contained in:
Kohaku-Blueleaf 2026-06-09 12:19:06 +08:00
parent 7775d2ab81
commit c7d29a42ba

View File

@ -10,7 +10,7 @@ from typing_extensions import override
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import ComfyExtension, io, Input, InputImpl, Types
def load_and_process_images(image_files, input_dir):
@ -46,44 +46,33 @@ def load_and_process_images(image_files, input_dir):
VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]
def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0):
"""Load video file and return frames as a tensor.
def _decode_selected_frames(video: Input.Video, indices: list[int]) -> Input.Video:
"""Decode only the requested frame indices from a video.
Args:
video_path: Path to the video file
max_frames: Maximum number of frames to load (0 = all)
frame_stride: Sample every Nth frame
start_frame: Frame index to start from
Returns:
torch.Tensor: Video frames as [T, H, W, C] float32 tensor in [0, 1]
Opens the underlying container once, decodes frames in presentation order,
keeps only the ones whose index is in ``indices``, and returns the result
wrapped in a VideoFromComponents so it still satisfies the VideoInput
contract for downstream nodes.
"""
container = av.open(video_path)
stream = container.streams.video[0]
indices_sorted = sorted(set(indices))
max_idx = indices_sorted[-1]
source = video.get_stream_source()
frames = []
frame_idx = 0
for frame in container.decode(stream):
if frame_idx < start_frame:
frame_idx += 1
continue
if (frame_idx - start_frame) % frame_stride != 0:
frame_idx += 1
continue
if max_frames > 0 and len(frames) >= max_frames:
break
frames_by_idx: dict[int, torch.Tensor] = {}
with av.open(source, mode="r") as container:
stream = container.streams.video[0]
wanted = set(indices_sorted)
for frame_idx, frame in enumerate(container.decode(stream)):
if frame_idx in wanted:
img = frame.to_ndarray(format="rgb24")
frames_by_idx[frame_idx] = torch.from_numpy(img.copy()).float() / 255.0
if frame_idx >= max_idx:
break
img = frame.to_ndarray(format='rgb24')
img_tensor = torch.from_numpy(img.copy()).float() / 255.0
frames.append(img_tensor)
frame_idx += 1
container.close()
if not frames:
raise ValueError(f"No frames could be extracted from {video_path}")
return torch.stack(frames) # [T, H, W, C]
stacked = torch.stack([frames_by_idx[i] for i in indices])
return InputImpl.VideoFromComponents(
Types.VideoComponents(images=stacked, frame_rate=video.get_frame_rate())
)
class LoadImageDataSetFromFolderNode(io.ComfyNode):
@ -211,39 +200,18 @@ class LoadVideoDataSetFromFolderNode(io.ComfyNode):
options=folder_paths.get_input_subfolders(),
tooltip="The folder containing video files.",
),
io.Int.Input(
"max_frames",
default=0,
min=0,
max=99999,
tooltip="Maximum frames to load per video (0 = all frames).",
),
io.Int.Input(
"frame_stride",
default=1,
min=1,
max=1000,
tooltip="Sample every Nth frame (1 = every frame).",
),
io.Int.Input(
"start_frame",
default=0,
min=0,
max=99999,
tooltip="Frame index to start loading from.",
),
],
outputs=[
io.Image.Output(
io.Video.Output(
display_name="videos",
is_output_list=True,
tooltip="List of video tensors, each [T, H, W, C].",
tooltip="Lazy video references; frames are decoded only when needed downstream.",
),
],
)
@classmethod
def execute(cls, folder, max_frames, frame_stride, start_frame):
def execute(cls, folder):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
video_files = sorted([
f for f in os.listdir(sub_input_dir)
@ -253,15 +221,9 @@ class LoadVideoDataSetFromFolderNode(io.ComfyNode):
if not video_files:
raise ValueError(f"No video files found in {sub_input_dir}")
output_videos = []
for file in video_files:
video_path = os.path.join(sub_input_dir, file)
frames = load_video_frames(video_path, max_frames, frame_stride, start_frame)
output_videos.append(frames)
logging.info(f"Loaded {file}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}")
logging.info(f"Loaded {len(output_videos)} videos from {sub_input_dir}")
return io.NodeOutput(output_videos)
videos = [InputImpl.VideoFromFile(os.path.join(sub_input_dir, f)) for f in video_files]
logging.info(f"Loaded {len(videos)} lazy video references from {sub_input_dir}")
return io.NodeOutput(videos)
class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
@ -278,33 +240,12 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
options=folder_paths.get_input_subfolders(),
tooltip="The folder containing video files and .txt captions.",
),
io.Int.Input(
"max_frames",
default=0,
min=0,
max=99999,
tooltip="Maximum frames to load per video (0 = all frames).",
),
io.Int.Input(
"frame_stride",
default=1,
min=1,
max=1000,
tooltip="Sample every Nth frame (1 = every frame).",
),
io.Int.Input(
"start_frame",
default=0,
min=0,
max=99999,
tooltip="Frame index to start loading from.",
),
],
outputs=[
io.Image.Output(
io.Video.Output(
display_name="videos",
is_output_list=True,
tooltip="List of video tensors, each [T, H, W, C].",
tooltip="Lazy video references; frames are decoded only when needed downstream.",
),
io.String.Output(
display_name="texts",
@ -315,7 +256,7 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
)
@classmethod
def execute(cls, folder, max_frames, frame_stride, start_frame):
def execute(cls, folder):
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
video_files = []
@ -337,7 +278,6 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
if not video_files:
raise ValueError(f"No video files found in {sub_input_dir}")
# Load captions (same name as video but .txt)
captions = []
for vf in video_files:
caption_path = os.path.splitext(vf)[0] + ".txt"
@ -347,14 +287,9 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
else:
captions.append("")
# Load videos
output_videos = []
for vf in video_files:
frames = load_video_frames(vf, max_frames, frame_stride, start_frame)
output_videos.append(frames)
logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_input_dir}")
return io.NodeOutput(output_videos, captions)
videos = [InputImpl.VideoFromFile(vf) for vf in video_files]
logging.info(f"Loaded {len(videos)} lazy video references with captions from {sub_input_dir}")
return io.NodeOutput(videos, captions)
def save_images_to_folder(image_list, output_dir, prefix="image"):
@ -1088,7 +1023,12 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
class VideoFrameSampleNode(io.ComfyNode):
"""Sample a fixed number of frames from a video using various strategies."""
"""Sample a fixed number of frames from a video using various strategies.
For contiguous strategies ("head"/"tail") the result is a fully lazy
VideoInput (no frames decoded). For non-contiguous strategies
("uniform"/"random") only the selected indices are decoded.
"""
@classmethod
def define_schema(cls):
@ -1098,7 +1038,7 @@ class VideoFrameSampleNode(io.ComfyNode):
category="dataset/video",
is_experimental=True,
inputs=[
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
io.Video.Input("video", tooltip="Input video."),
io.Int.Input(
"num_frames",
default=16,
@ -1121,20 +1061,27 @@ class VideoFrameSampleNode(io.ComfyNode):
),
],
outputs=[
io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."),
io.Video.Output(display_name="video", tooltip="Sampled video."),
],
)
@classmethod
def execute(cls, video, num_frames, strategy, seed):
total_frames = video.shape[0]
total_frames = video.get_frame_count()
num_frames = min(num_frames, total_frames)
fps = float(video.get_frame_rate())
if strategy == "head":
indices = list(range(num_frames))
elif strategy == "tail":
indices = list(range(total_frames - num_frames, total_frames))
elif strategy == "uniform":
return io.NodeOutput(
video.as_trimmed(0.0, num_frames / fps, strict_duration=False)
)
if strategy == "tail":
start_t = (total_frames - num_frames) / fps
return io.NodeOutput(
video.as_trimmed(start_t, num_frames / fps, strict_duration=False)
)
if strategy == "uniform":
if num_frames == 1:
indices = [total_frames // 2]
else:
@ -1145,11 +1092,11 @@ class VideoFrameSampleNode(io.ComfyNode):
else:
raise ValueError(f"Unknown strategy: {strategy}")
return io.NodeOutput(video[indices])
return io.NodeOutput(_decode_selected_frames(video, indices))
class VideoTemporalCropNode(io.ComfyNode):
"""Crop a continuous range of frames from a video."""
"""Crop a continuous range of frames from a video (fully lazy)."""
@classmethod
def define_schema(cls):
@ -1159,7 +1106,7 @@ class VideoTemporalCropNode(io.ComfyNode):
category="dataset/video",
is_experimental=True,
inputs=[
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
io.Video.Input("video", tooltip="Input video."),
io.Int.Input(
"start_frame",
default=0,
@ -1176,20 +1123,23 @@ class VideoTemporalCropNode(io.ComfyNode):
),
],
outputs=[
io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
io.Video.Output(display_name="video", tooltip="Cropped video (lazy)."),
],
)
@classmethod
def execute(cls, video, start_frame, length):
total_frames = video.shape[0]
total_frames = video.get_frame_count()
fps = float(video.get_frame_rate())
start_frame = min(start_frame, max(total_frames - 1, 0))
end_frame = min(start_frame + length, total_frames)
return io.NodeOutput(video[start_frame:end_frame])
length = min(length, total_frames - start_frame)
return io.NodeOutput(
video.as_trimmed(start_frame / fps, length / fps, strict_duration=False)
)
class VideoRandomTemporalCropNode(io.ComfyNode):
"""Randomly crop a continuous range of frames from a video (for data augmentation)."""
"""Randomly crop a continuous range of frames from a video (fully lazy)."""
@classmethod
def define_schema(cls):
@ -1199,7 +1149,7 @@ class VideoRandomTemporalCropNode(io.ComfyNode):
category="dataset/video",
is_experimental=True,
inputs=[
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
io.Video.Input("video", tooltip="Input video."),
io.Int.Input(
"length",
default=16,
@ -1216,49 +1166,42 @@ class VideoRandomTemporalCropNode(io.ComfyNode):
),
],
outputs=[
io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
io.Video.Output(display_name="video", tooltip="Cropped video (lazy)."),
],
)
@classmethod
def execute(cls, video, length, seed):
total_frames = video.shape[0]
total_frames = video.get_frame_count()
fps = float(video.get_frame_rate())
length = min(length, total_frames)
max_start = total_frames - length
rng = np.random.RandomState(seed % (2**32 - 1))
start = rng.randint(0, max_start + 1) if max_start > 0 else 0
return io.NodeOutput(video[start:start + length])
return io.NodeOutput(
video.as_trimmed(start / fps, length / fps, strict_duration=False)
)
class ShuffleVideoDatasetNode(ImageProcessingNode):
class ShuffleVideoDatasetNode(io.ComfyNode):
"""Randomly shuffle the order of videos in the dataset."""
node_id = "ShuffleVideoDataset"
display_name = "Shuffle Video Dataset"
description = "Randomly shuffle the order of videos in the dataset."
is_group_process = True
extra_inputs = [
io.Int.Input(
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
),
]
@classmethod
def define_schema(cls):
return io.Schema(
node_id=cls.node_id,
display_name=cls.display_name,
node_id="ShuffleVideoDataset",
display_name="Shuffle Video Dataset",
category="dataset/video",
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("images", tooltip="List of videos to shuffle."),
io.Video.Input("videos", tooltip="List of videos to shuffle."),
io.Int.Input(
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
),
],
outputs=[
io.Image.Output(
io.Video.Output(
display_name="videos",
is_output_list=True,
tooltip="Shuffled videos",
@ -1267,11 +1210,11 @@ class ShuffleVideoDatasetNode(ImageProcessingNode):
)
@classmethod
def execute(cls, images, seed):
def execute(cls, videos, seed):
seed = seed[0] if isinstance(seed, list) else seed
np.random.seed(seed % (2**32 - 1))
indices = np.random.permutation(len(images))
return io.NodeOutput([images[i] for i in indices])
indices = np.random.permutation(len(videos))
return io.NodeOutput([videos[i] for i in indices])
class ShuffleVideoTextDatasetNode(io.ComfyNode):
@ -1286,7 +1229,7 @@ class ShuffleVideoTextDatasetNode(io.ComfyNode):
is_experimental=True,
is_input_list=True,
inputs=[
io.Image.Input("videos", tooltip="List of videos to shuffle."),
io.Video.Input("videos", tooltip="List of videos to shuffle."),
io.String.Input("texts", tooltip="List of texts to shuffle."),
io.Int.Input(
"seed",
@ -1297,7 +1240,7 @@ class ShuffleVideoTextDatasetNode(io.ComfyNode):
),
],
outputs=[
io.Image.Output(
io.Video.Output(
display_name="videos",
is_output_list=True,
tooltip="Shuffled videos",
@ -1976,7 +1919,7 @@ class DatasetExtension(ComfyExtension):
AdjustContrastNode,
ShuffleDatasetNode,
ShuffleImageTextDatasetNode,
# Video processing nodes
# Video processing nodes (lazy VideoInput in/out)
VideoFrameSampleNode,
VideoTemporalCropNode,
VideoRandomTemporalCropNode,