mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
refactor(video dataset): lazy video loading and frame-selective decode
- Replace eager load_video_frames() with _decode_selected_frames() that opens the container with `with av.open(...)` (no resource leak) and decodes only the requested frame indices. - Video loader nodes now emit lazy VideoFromFile references; sampling and temporal-crop nodes operate lazily / decode only selected frames.
This commit is contained in:
parent
7775d2ab81
commit
c7d29a42ba
@ -10,7 +10,7 @@ from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.latest import ComfyExtension, io, Input, InputImpl, Types
|
||||
|
||||
|
||||
def load_and_process_images(image_files, input_dir):
|
||||
@ -46,44 +46,33 @@ def load_and_process_images(image_files, input_dir):
|
||||
VALID_VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]
|
||||
|
||||
|
||||
def load_video_frames(video_path, max_frames=0, frame_stride=1, start_frame=0):
|
||||
"""Load video file and return frames as a tensor.
|
||||
def _decode_selected_frames(video: Input.Video, indices: list[int]) -> Input.Video:
|
||||
"""Decode only the requested frame indices from a video.
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file
|
||||
max_frames: Maximum number of frames to load (0 = all)
|
||||
frame_stride: Sample every Nth frame
|
||||
start_frame: Frame index to start from
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Video frames as [T, H, W, C] float32 tensor in [0, 1]
|
||||
Opens the underlying container once, decodes frames in presentation order,
|
||||
keeps only the ones whose index is in ``indices``, and returns the result
|
||||
wrapped in a VideoFromComponents so it still satisfies the VideoInput
|
||||
contract for downstream nodes.
|
||||
"""
|
||||
container = av.open(video_path)
|
||||
stream = container.streams.video[0]
|
||||
indices_sorted = sorted(set(indices))
|
||||
max_idx = indices_sorted[-1]
|
||||
source = video.get_stream_source()
|
||||
|
||||
frames = []
|
||||
frame_idx = 0
|
||||
for frame in container.decode(stream):
|
||||
if frame_idx < start_frame:
|
||||
frame_idx += 1
|
||||
continue
|
||||
if (frame_idx - start_frame) % frame_stride != 0:
|
||||
frame_idx += 1
|
||||
continue
|
||||
if max_frames > 0 and len(frames) >= max_frames:
|
||||
break
|
||||
frames_by_idx: dict[int, torch.Tensor] = {}
|
||||
with av.open(source, mode="r") as container:
|
||||
stream = container.streams.video[0]
|
||||
wanted = set(indices_sorted)
|
||||
for frame_idx, frame in enumerate(container.decode(stream)):
|
||||
if frame_idx in wanted:
|
||||
img = frame.to_ndarray(format="rgb24")
|
||||
frames_by_idx[frame_idx] = torch.from_numpy(img.copy()).float() / 255.0
|
||||
if frame_idx >= max_idx:
|
||||
break
|
||||
|
||||
img = frame.to_ndarray(format='rgb24')
|
||||
img_tensor = torch.from_numpy(img.copy()).float() / 255.0
|
||||
frames.append(img_tensor)
|
||||
frame_idx += 1
|
||||
|
||||
container.close()
|
||||
|
||||
if not frames:
|
||||
raise ValueError(f"No frames could be extracted from {video_path}")
|
||||
|
||||
return torch.stack(frames) # [T, H, W, C]
|
||||
stacked = torch.stack([frames_by_idx[i] for i in indices])
|
||||
return InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(images=stacked, frame_rate=video.get_frame_rate())
|
||||
)
|
||||
|
||||
|
||||
class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
||||
@ -211,39 +200,18 @@ class LoadVideoDataSetFromFolderNode(io.ComfyNode):
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder containing video files.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"max_frames",
|
||||
default=0,
|
||||
min=0,
|
||||
max=99999,
|
||||
tooltip="Maximum frames to load per video (0 = all frames).",
|
||||
),
|
||||
io.Int.Input(
|
||||
"frame_stride",
|
||||
default=1,
|
||||
min=1,
|
||||
max=1000,
|
||||
tooltip="Sample every Nth frame (1 = every frame).",
|
||||
),
|
||||
io.Int.Input(
|
||||
"start_frame",
|
||||
default=0,
|
||||
min=0,
|
||||
max=99999,
|
||||
tooltip="Frame index to start loading from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(
|
||||
io.Video.Output(
|
||||
display_name="videos",
|
||||
is_output_list=True,
|
||||
tooltip="List of video tensors, each [T, H, W, C].",
|
||||
tooltip="Lazy video references; frames are decoded only when needed downstream.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, folder, max_frames, frame_stride, start_frame):
|
||||
def execute(cls, folder):
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
video_files = sorted([
|
||||
f for f in os.listdir(sub_input_dir)
|
||||
@ -253,15 +221,9 @@ class LoadVideoDataSetFromFolderNode(io.ComfyNode):
|
||||
if not video_files:
|
||||
raise ValueError(f"No video files found in {sub_input_dir}")
|
||||
|
||||
output_videos = []
|
||||
for file in video_files:
|
||||
video_path = os.path.join(sub_input_dir, file)
|
||||
frames = load_video_frames(video_path, max_frames, frame_stride, start_frame)
|
||||
output_videos.append(frames)
|
||||
logging.info(f"Loaded {file}: {frames.shape[0]} frames, {frames.shape[1]}x{frames.shape[2]}")
|
||||
|
||||
logging.info(f"Loaded {len(output_videos)} videos from {sub_input_dir}")
|
||||
return io.NodeOutput(output_videos)
|
||||
videos = [InputImpl.VideoFromFile(os.path.join(sub_input_dir, f)) for f in video_files]
|
||||
logging.info(f"Loaded {len(videos)} lazy video references from {sub_input_dir}")
|
||||
return io.NodeOutput(videos)
|
||||
|
||||
|
||||
class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
|
||||
@ -278,33 +240,12 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
|
||||
options=folder_paths.get_input_subfolders(),
|
||||
tooltip="The folder containing video files and .txt captions.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"max_frames",
|
||||
default=0,
|
||||
min=0,
|
||||
max=99999,
|
||||
tooltip="Maximum frames to load per video (0 = all frames).",
|
||||
),
|
||||
io.Int.Input(
|
||||
"frame_stride",
|
||||
default=1,
|
||||
min=1,
|
||||
max=1000,
|
||||
tooltip="Sample every Nth frame (1 = every frame).",
|
||||
),
|
||||
io.Int.Input(
|
||||
"start_frame",
|
||||
default=0,
|
||||
min=0,
|
||||
max=99999,
|
||||
tooltip="Frame index to start loading from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(
|
||||
io.Video.Output(
|
||||
display_name="videos",
|
||||
is_output_list=True,
|
||||
tooltip="List of video tensors, each [T, H, W, C].",
|
||||
tooltip="Lazy video references; frames are decoded only when needed downstream.",
|
||||
),
|
||||
io.String.Output(
|
||||
display_name="texts",
|
||||
@ -315,7 +256,7 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, folder, max_frames, frame_stride, start_frame):
|
||||
def execute(cls, folder):
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
|
||||
video_files = []
|
||||
@ -337,7 +278,6 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
|
||||
if not video_files:
|
||||
raise ValueError(f"No video files found in {sub_input_dir}")
|
||||
|
||||
# Load captions (same name as video but .txt)
|
||||
captions = []
|
||||
for vf in video_files:
|
||||
caption_path = os.path.splitext(vf)[0] + ".txt"
|
||||
@ -347,14 +287,9 @@ class LoadVideoTextDataSetFromFolderNode(io.ComfyNode):
|
||||
else:
|
||||
captions.append("")
|
||||
|
||||
# Load videos
|
||||
output_videos = []
|
||||
for vf in video_files:
|
||||
frames = load_video_frames(vf, max_frames, frame_stride, start_frame)
|
||||
output_videos.append(frames)
|
||||
|
||||
logging.info(f"Loaded {len(output_videos)} videos with captions from {sub_input_dir}")
|
||||
return io.NodeOutput(output_videos, captions)
|
||||
videos = [InputImpl.VideoFromFile(vf) for vf in video_files]
|
||||
logging.info(f"Loaded {len(videos)} lazy video references with captions from {sub_input_dir}")
|
||||
return io.NodeOutput(videos, captions)
|
||||
|
||||
|
||||
def save_images_to_folder(image_list, output_dir, prefix="image"):
|
||||
@ -1088,7 +1023,12 @@ class ShuffleImageTextDatasetNode(io.ComfyNode):
|
||||
|
||||
|
||||
class VideoFrameSampleNode(io.ComfyNode):
|
||||
"""Sample a fixed number of frames from a video using various strategies."""
|
||||
"""Sample a fixed number of frames from a video using various strategies.
|
||||
|
||||
For contiguous strategies ("head"/"tail") the result is a fully lazy
|
||||
VideoInput (no frames decoded). For non-contiguous strategies
|
||||
("uniform"/"random") only the selected indices are decoded.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1098,7 +1038,7 @@ class VideoFrameSampleNode(io.ComfyNode):
|
||||
category="dataset/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
|
||||
io.Video.Input("video", tooltip="Input video."),
|
||||
io.Int.Input(
|
||||
"num_frames",
|
||||
default=16,
|
||||
@ -1121,20 +1061,27 @@ class VideoFrameSampleNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="video", tooltip="Sampled video [N, H, W, C]."),
|
||||
io.Video.Output(display_name="video", tooltip="Sampled video."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video, num_frames, strategy, seed):
|
||||
total_frames = video.shape[0]
|
||||
total_frames = video.get_frame_count()
|
||||
num_frames = min(num_frames, total_frames)
|
||||
fps = float(video.get_frame_rate())
|
||||
|
||||
if strategy == "head":
|
||||
indices = list(range(num_frames))
|
||||
elif strategy == "tail":
|
||||
indices = list(range(total_frames - num_frames, total_frames))
|
||||
elif strategy == "uniform":
|
||||
return io.NodeOutput(
|
||||
video.as_trimmed(0.0, num_frames / fps, strict_duration=False)
|
||||
)
|
||||
if strategy == "tail":
|
||||
start_t = (total_frames - num_frames) / fps
|
||||
return io.NodeOutput(
|
||||
video.as_trimmed(start_t, num_frames / fps, strict_duration=False)
|
||||
)
|
||||
|
||||
if strategy == "uniform":
|
||||
if num_frames == 1:
|
||||
indices = [total_frames // 2]
|
||||
else:
|
||||
@ -1145,11 +1092,11 @@ class VideoFrameSampleNode(io.ComfyNode):
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy}")
|
||||
|
||||
return io.NodeOutput(video[indices])
|
||||
return io.NodeOutput(_decode_selected_frames(video, indices))
|
||||
|
||||
|
||||
class VideoTemporalCropNode(io.ComfyNode):
|
||||
"""Crop a continuous range of frames from a video."""
|
||||
"""Crop a continuous range of frames from a video (fully lazy)."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1159,7 +1106,7 @@ class VideoTemporalCropNode(io.ComfyNode):
|
||||
category="dataset/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
|
||||
io.Video.Input("video", tooltip="Input video."),
|
||||
io.Int.Input(
|
||||
"start_frame",
|
||||
default=0,
|
||||
@ -1176,20 +1123,23 @@ class VideoTemporalCropNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
|
||||
io.Video.Output(display_name="video", tooltip="Cropped video (lazy)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video, start_frame, length):
|
||||
total_frames = video.shape[0]
|
||||
total_frames = video.get_frame_count()
|
||||
fps = float(video.get_frame_rate())
|
||||
start_frame = min(start_frame, max(total_frames - 1, 0))
|
||||
end_frame = min(start_frame + length, total_frames)
|
||||
return io.NodeOutput(video[start_frame:end_frame])
|
||||
length = min(length, total_frames - start_frame)
|
||||
return io.NodeOutput(
|
||||
video.as_trimmed(start_frame / fps, length / fps, strict_duration=False)
|
||||
)
|
||||
|
||||
|
||||
class VideoRandomTemporalCropNode(io.ComfyNode):
|
||||
"""Randomly crop a continuous range of frames from a video (for data augmentation)."""
|
||||
"""Randomly crop a continuous range of frames from a video (fully lazy)."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -1199,7 +1149,7 @@ class VideoRandomTemporalCropNode(io.ComfyNode):
|
||||
category="dataset/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("video", tooltip="Video tensor [T, H, W, C]."),
|
||||
io.Video.Input("video", tooltip="Input video."),
|
||||
io.Int.Input(
|
||||
"length",
|
||||
default=16,
|
||||
@ -1216,49 +1166,42 @@ class VideoRandomTemporalCropNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="video", tooltip="Cropped video [length, H, W, C]."),
|
||||
io.Video.Output(display_name="video", tooltip="Cropped video (lazy)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video, length, seed):
|
||||
total_frames = video.shape[0]
|
||||
total_frames = video.get_frame_count()
|
||||
fps = float(video.get_frame_rate())
|
||||
length = min(length, total_frames)
|
||||
max_start = total_frames - length
|
||||
rng = np.random.RandomState(seed % (2**32 - 1))
|
||||
start = rng.randint(0, max_start + 1) if max_start > 0 else 0
|
||||
return io.NodeOutput(video[start:start + length])
|
||||
return io.NodeOutput(
|
||||
video.as_trimmed(start / fps, length / fps, strict_duration=False)
|
||||
)
|
||||
|
||||
|
||||
class ShuffleVideoDatasetNode(ImageProcessingNode):
|
||||
class ShuffleVideoDatasetNode(io.ComfyNode):
|
||||
"""Randomly shuffle the order of videos in the dataset."""
|
||||
|
||||
node_id = "ShuffleVideoDataset"
|
||||
display_name = "Shuffle Video Dataset"
|
||||
description = "Randomly shuffle the order of videos in the dataset."
|
||||
is_group_process = True
|
||||
extra_inputs = [
|
||||
io.Int.Input(
|
||||
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
display_name=cls.display_name,
|
||||
node_id="ShuffleVideoDataset",
|
||||
display_name="Shuffle Video Dataset",
|
||||
category="dataset/video",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="List of videos to shuffle."),
|
||||
io.Video.Input("videos", tooltip="List of videos to shuffle."),
|
||||
io.Int.Input(
|
||||
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(
|
||||
io.Video.Output(
|
||||
display_name="videos",
|
||||
is_output_list=True,
|
||||
tooltip="Shuffled videos",
|
||||
@ -1267,11 +1210,11 @@ class ShuffleVideoDatasetNode(ImageProcessingNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, seed):
|
||||
def execute(cls, videos, seed):
|
||||
seed = seed[0] if isinstance(seed, list) else seed
|
||||
np.random.seed(seed % (2**32 - 1))
|
||||
indices = np.random.permutation(len(images))
|
||||
return io.NodeOutput([images[i] for i in indices])
|
||||
indices = np.random.permutation(len(videos))
|
||||
return io.NodeOutput([videos[i] for i in indices])
|
||||
|
||||
|
||||
class ShuffleVideoTextDatasetNode(io.ComfyNode):
|
||||
@ -1286,7 +1229,7 @@ class ShuffleVideoTextDatasetNode(io.ComfyNode):
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
inputs=[
|
||||
io.Image.Input("videos", tooltip="List of videos to shuffle."),
|
||||
io.Video.Input("videos", tooltip="List of videos to shuffle."),
|
||||
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
||||
io.Int.Input(
|
||||
"seed",
|
||||
@ -1297,7 +1240,7 @@ class ShuffleVideoTextDatasetNode(io.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(
|
||||
io.Video.Output(
|
||||
display_name="videos",
|
||||
is_output_list=True,
|
||||
tooltip="Shuffled videos",
|
||||
@ -1976,7 +1919,7 @@ class DatasetExtension(ComfyExtension):
|
||||
AdjustContrastNode,
|
||||
ShuffleDatasetNode,
|
||||
ShuffleImageTextDatasetNode,
|
||||
# Video processing nodes
|
||||
# Video processing nodes (lazy VideoInput in/out)
|
||||
VideoFrameSampleNode,
|
||||
VideoTemporalCropNode,
|
||||
VideoRandomTemporalCropNode,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user