from __future__ import annotations import os import av import torch import folder_paths import json import logging from typing import Optional from typing_extensions import override from fractions import Fraction from comfy_api.input import AudioInput, ImageInput, VideoInput from comfy_api.input_impl import VideoFromComponents, VideoFromFile from comfy_api.util import VideoCodec, VideoComponents, VideoContainer from comfy_api.latest import ComfyExtension, io, ui from comfy.cli_args import args import comfy.utils class EncodeVideo(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="EncodeVideo", display_name="Encode Video", category="image/video", description="Encode a video using an image encoder.", inputs=[ io.Video.Input("video", tooltip="The video to be encoded."), io.Int.Input( "processing_batch_size", default=-1, min=-1, tooltip=( "Number of frames/segments to process at a time during encoding.\n" "-1 means process all at once. Smaller values reduce GPU memory usage." ), ), io.Int.Input("step_size", default=8, min=1, max=32, tooltip=( "Stride (in frames) between the start of consecutive segments.\n" "Smaller step = more overlap and smoother temporal coverage " "but higher compute cost. Larger step = faster but may miss detail." ), ), io.Vae.Input("vae", optional=True), io.ClipVision.Input("clip_vision", optional=True), ], outputs=[ io.Conditioning.Output(display_name="encoded_video"), ], ) @classmethod def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None): video = video.images if not isinstance(video, torch.Tensor): video = torch.from_numpy(video) t, *rest = video.shape # channel last if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4): video = video.permute(0, 3, 1, 2) t, c, h, w = video.shape device = video.device b = 1 batch_size = b * t if vae is not None and clip_vision is not None: raise ValueError("Must either have vae or clip_vision.") elif vae is None and clip_vision is None: raise ValueError("Can't have VAE and Clip Vision passed at the same time!") model = vae.first_stage_model if vae is not None else clip_vision.model vae = vae if vae is not None else clip_vision if hasattr(model, "video_encoding"): data, num_segments, output_fn = model.video_encoding(video, step_size) batch_size = b * num_segments else: data = video.view(batch_size, c, h, w) output_fn = lambda x: x.view(b, t, -1) if processing_batch_size != -1: batch_size = processing_batch_size outputs = None total = data.shape[0] pbar = comfy.utils.ProgressBar(total/batch_size) model_dtype = next(model.parameters()).dtype with torch.inference_mode(): for i in range(0, total, batch_size): chunk = data[i : i + batch_size].to(device, non_blocking = True) chunk = chunk.to(model_dtype) if hasattr(vae, "encode"): try: if chunk.ndim > 5: raise ValueError("chunk.ndim > 5") chunk = chunk.movedim(1, -1) out = vae.encode(chunk) except Exception: out = model.encode(chunk) else: chunk = chunk.movedim(1, -1) out = vae.encode_image(chunk.to(torch.uint8), crop=False, resize_mode="bilinear") out = out["image_embeds"] out_cpu = out.cpu() if outputs is None: full_shape = (total, *out_cpu.shape[1:]) # should be the offload device outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True) chunk_len = out_cpu.shape[0] outputs[i : i + chunk_len].copy_(out_cpu) del out, chunk, out_cpu torch.cuda.empty_cache() pbar.update(1) return io.NodeOutput(output_fn(outputs)) class ResampleVideo(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="ResampleVideo", display_name="Resample Video", category="image/video", inputs = [ io.Video.Input("video"), io.Int.Input("target_fps", min=1, default=25) ], outputs=[io.Video.Output(display_name="video")] ) @classmethod def execute(cls, video, target_fps: int): # doesn't support upsampling video_components = video.get_components() with av.open(video.get_stream_source(), mode="r") as container: stream = container.streams.video[0] frames = [] src_rate = stream.average_rate or stream.guessed_rate src_fps = float(src_rate) if src_rate else None if src_fps is None: logging.warning("src_fps for video resampling is None.") # yield original frames if asked for upsampling if target_fps > src_fps: return io.NodeOutput(video_components) stream.thread_type = "AUTO" next_time = 0.0 step = 1.0 / target_fps for packet in container.demux(stream): for frame in packet.decode(): if frame.time is None: continue t = frame.time while t >= next_time: arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() frames.append(arr) next_time += step new_components = VideoComponents( images=torch.stack(frames), audio=video_components.audio, frame_rate=Fraction(target_fps, 1), metadata=video_components.metadata, ) return io.NodeOutput(new_components) class SaveWEBM(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="SaveWEBM", category="image/video", is_experimental=True, inputs=[ io.Image.Input("images"), io.String.Input("filename_prefix", default="ComfyUI"), io.Combo.Input("codec", options=["vp9", "av1"]), io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), ], outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0] ) file = f"{filename}_{counter:05}_.webm" container = av.open(os.path.join(full_output_folder, file), mode="w") if cls.hidden.prompt is not None: container.metadata["prompt"] = json.dumps(cls.hidden.prompt) if cls.hidden.extra_pnginfo is not None: for x in cls.hidden.extra_pnginfo: container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) stream.width = images.shape[-2] stream.height = images.shape[-3] stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p" stream.bit_rate = 0 stream.options = {'crf': str(crf)} if codec == "av1": stream.options["preset"] = "6" for frame in images: frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") for packet in stream.encode(frame): container.mux(packet) container.mux(stream.encode()) container.close() return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class SaveVideo(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="SaveVideo", display_name="Save Video", category="image/video", description="Saves the input images to your ComfyUI output directory.", inputs=[ io.Video.Input("video", tooltip="The video to save."), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), ], outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, folder_paths.get_output_directory(), width, height ) saved_metadata = None if not args.disable_metadata: metadata = {} if cls.hidden.extra_pnginfo is not None: metadata.update(cls.hidden.extra_pnginfo) if cls.hidden.prompt is not None: metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), format=format, codec=codec, metadata=saved_metadata ) return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class CreateVideo(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="CreateVideo", display_name="Create Video", category="image/video", description="Create a video from images.", inputs=[ io.Image.Input("images", tooltip="The images to create a video from."), io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), ], outputs=[ io.Video.Output(), ], ) @classmethod def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: return io.NodeOutput( VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) ) class GetVideoComponents(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="GetVideoComponents", display_name="Get Video Components", category="image/video", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), ], outputs=[ io.Image.Output(display_name="images"), io.Audio.Output(display_name="audio"), io.Float.Output(display_name="fps"), ], ) @classmethod def execute(cls, video: VideoInput) -> io.NodeOutput: components = video.get_components() return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) class LoadVideo(io.ComfyNode): @classmethod def define_schema(cls): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] files = folder_paths.filter_files_content_types(files, ["video"]) return io.Schema( node_id="LoadVideo", display_name="Load Video", category="image/video", inputs=[ io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), ], outputs=[ io.Video.Output(), ], ) @classmethod def execute(cls, file) -> io.NodeOutput: video_path = folder_paths.get_annotated_filepath(file) return io.NodeOutput(VideoFromFile(video_path)) @classmethod def fingerprint_inputs(s, file): video_path = folder_paths.get_annotated_filepath(file) mod_time = os.path.getmtime(video_path) # Instead of hashing the file, we can just use the modification time to avoid # rehashing large files. return mod_time @classmethod def validate_inputs(s, file): if not folder_paths.exists_annotated_filepath(file): return "Invalid video file: {}".format(file) return True class VideoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ SaveWEBM, SaveVideo, CreateVideo, GetVideoComponents, LoadVideo, EncodeVideo, ResampleVideo, ] async def comfy_entrypoint() -> VideoExtension: return VideoExtension()