From 3d21d2afb6eaa04aa7beae3648793329ad5ddfdb Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 13 Apr 2026 22:34:51 +1000 Subject: [PATCH] nodes_video: implement stream saver nodes --- comfy_extras/nodes_video.py | 242 +++++++++++++++++++++++++++++++++++- 1 file changed, 241 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 5c096c232..f04653d28 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -5,11 +5,162 @@ import av import torch import folder_paths import json -from typing import Optional +from typing import Callable, Optional from typing_extensions import override from fractions import Fraction from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types from comfy.cli_args import args +from comfy_execution.utils import get_executing_context +from comfy_extras.nodes_image_stream import FrameProgressTracker, drain_image_stream +from server import PromptServer + + +class SavedVideoStream(Input.ImageStream): + def __init__( + self, + stream: Input.ImageStream, + saver, + output_factory: Callable[[], tuple[str, ui.PreviewVideo | None]], + format: str, + codec, + metadata: Optional[dict], + emit_preview_on_finalize: bool = True, + preview_node_id: Optional[str] = None, + preview_display_node_id: Optional[str] = None, + ): + super().__init__() + self._stream = stream + self._saver = saver + self._output_factory = output_factory + self._path: str | None = None + self._format = Types.VideoContainer(format) + self._codec = Types.VideoCodec(codec) + self._metadata = metadata + self._preview_ui: ui.PreviewVideo | None = None + self._emit_preview_on_finalize = emit_preview_on_finalize + self._preview_node_id = preview_node_id + self._preview_display_node_id = preview_display_node_id + self._save_state = None + + def _emit_preview(self) -> None: + if not self._emit_preview_on_finalize or self._preview_ui is None: + return + + current = get_executing_context() + if current is None: + return + + server = getattr(PromptServer, "instance", None) + if server is None or server.client_id is None: + return + + server.send_sync( + "executed", + { + "node": self._preview_node_id or current.node_id, + "display_node": self._preview_display_node_id or self._preview_node_id or current.node_id, + "output": self._preview_ui.as_dict(), + "prompt_id": current.prompt_id, + }, + server.client_id, + ) + + def _discard_partial_output(self) -> None: + if self._save_state is not None: + self._save_state[0].close() + self._save_state = None + if self._path is not None and os.path.exists(self._path): + os.remove(self._path) + self._path = None + self._preview_ui = None + + def get_preview_ui(self) -> ui.PreviewVideo | None: + return self._preview_ui + + def get_dimensions(self) -> tuple[int, int]: + return self._stream.get_dimensions() + + def do_reset(self) -> None: + self._discard_partial_output() + self._stream.reset() + self._path, self._preview_ui = self._output_factory() + assert self._path is not None + open(self._path, "ab").close() + self._save_state = self._saver.save_start( + self._path, + format=self._format, + codec=self._codec, + metadata=self._metadata, + ) + + def do_pull(self, max_frames: int) -> Input.Image: + assert self._save_state is not None + chunk = self._stream.pull(max_frames) + self._saver.save_add(self._save_state[0], self._save_state[1], chunk) + if chunk.shape[0] < max_frames: + self._saver.save_finalize(*self._save_state) + self._save_state = None + self._emit_preview() + self._path = None + return chunk + + +def _build_saved_stream( + hidden, + stream: Input.ImageStream, + audio: Optional[Input.Audio], + fps: float, + filename_prefix, + format: str, + codec, + emit_preview: bool = True, +) -> SavedVideoStream: + width, height = stream.get_dimensions() + saved_metadata = None + if not args.disable_metadata: + metadata = {} + if hidden.extra_pnginfo is not None: + metadata.update(hidden.extra_pnginfo) + if hidden.prompt is not None: + metadata["prompt"] = hidden.prompt + if len(metadata) > 0: + saved_metadata = metadata + + preview_node_id = hidden.unique_id + preview_display_node_id = preview_node_id + if hidden.dynprompt is not None and preview_node_id is not None: + preview_display_node_id = hidden.dynprompt.get_display_node_id(preview_node_id) + + def output_factory() -> tuple[str, ui.PreviewVideo]: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, + folder_paths.get_output_directory(), + width, + height, + ) + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" + return ( + os.path.join(full_output_folder, file), + ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]), + ) + + return SavedVideoStream( + stream, + InputImpl.VideoFromComponents( + Types.VideoComponents( + images=torch.zeros((0, height, width, 3)), + audio=audio, + frame_rate=Fraction(fps), + ) + ), + output_factory, + format, + codec, + saved_metadata, + emit_preview, + preview_node_id, + preview_display_node_id, + ) class SaveWEBM(io.ComfyNode): @classmethod @@ -114,6 +265,93 @@ class SaveVideo(io.ComfyNode): return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) +class SavePassthroughVideoStream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SavePassthroughVideoStream", + search_aliases=["stream to video", "save image stream", "export video stream", "passthrough video stream"], + display_name="Save+Passthrough Video Stream", + category="image/video", + essentials_category="Basics", + description="Saves frames as they pass through the input image stream.", + has_intermediate_output=True, + inputs=[ + io.ImageStream.Input("stream", tooltip="The image stream to save."), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[ + io.ImageStream.Output(display_name="passthrough"), + ], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt], + ) + + @classmethod + def execute( + cls, + stream: Input.ImageStream, + audio: Optional[Input.Audio], + fps: float, + filename_prefix, + format: str, + codec, + ) -> io.NodeOutput: + return io.NodeOutput(_build_saved_stream(cls.hidden, stream, audio, fps, filename_prefix, format, codec)) + + +class SaveVideoStream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveVideoStream", + search_aliases=["save image stream", "export video stream", "stream to video"], + display_name="Save Video Stream", + category="image/video", + essentials_category="Basics", + description="Saves an image stream by draining it directly to EOF.", + is_output_node=True, + inputs=[ + io.ImageStream.Input("stream", tooltip="The image stream to save."), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.Int.Input("chunk_size", default=8, min=1, max=4096), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo, io.Hidden.unique_id, io.Hidden.dynprompt], + ) + + @classmethod + def execute( + cls, + stream: Input.ImageStream, + audio: Optional[Input.Audio], + fps: float, + chunk_size: int, + filename_prefix, + format: str, + codec, + ) -> io.NodeOutput: + saved_stream = _build_saved_stream( + cls.hidden, + stream, + audio, + fps, + filename_prefix, + format, + codec, + emit_preview=False, + ) + drain_image_stream(saved_stream, chunk_size, progress=FrameProgressTracker()) + return io.NodeOutput(ui=saved_stream.get_preview_ui()) + + class CreateVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -262,6 +500,8 @@ class VideoExtension(ComfyExtension): return [ SaveWEBM, SaveVideo, + SavePassthroughVideoStream, + SaveVideoStream, CreateVideo, GetVideoComponents, LoadVideo,